Skip to content

Commit 868cc4c

Browse files
committed
revisiting gracefulClose with STM racing
1 parent 287f2a9 commit 868cc4c

File tree

1 file changed

+71
-17
lines changed

1 file changed

+71
-17
lines changed

Network/Socket/Shutdown.hs

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@ module Network.Socket.Shutdown (
88
, gracefulClose
99
) where
1010

11+
import Control.Concurrent (threadDelay, yield)
1112
import qualified Control.Exception as E
1213
import Foreign.Marshal.Alloc (mallocBytes, free)
1314

14-
import Control.Concurrent (threadDelay, yield)
15+
#if !defined(mingw32_HOST_OS)
16+
import Control.Concurrent.STM
17+
import qualified GHC.Event as Ev
18+
#endif
1519

1620
import Network.Socket.Buffer
1721
import Network.Socket.Imports
1822
import Network.Socket.Internal
23+
import Network.Socket.STM
1924
import Network.Socket.Types
2025

2126
data ShutdownCmd = ShutdownReceive
@@ -59,19 +64,68 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s
5964
-- FIN arrives meanwhile.
6065
yield
6166
-- Waiting TCP FIN.
62-
E.bracket (mallocBytes bufSize) free recvEOFloop
63-
recvEOFloop buf = loop 1 0
64-
where
65-
loop delay tmout = do
66-
-- We don't check the (positive) length.
67-
-- In normal case, it's 0. That is, only FIN is received.
68-
-- In error cases, data is available. But there is no
69-
-- application which can read it. So, let's stop receiving
70-
-- to prevent attacks.
71-
r <- recvBufNoWait s buf bufSize
72-
when (r == -1 && tmout < tmout0) $ do
73-
threadDelay (delay * 1000)
74-
loop (delay * 2) (tmout + delay)
75-
-- Don't use 4092 here. The GHC runtime takes the global lock
76-
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
77-
bufSize = 1024
67+
E.bracket (mallocBytes bufSize) free (recvEOF s tmout0)
68+
69+
recvEOF :: Socket -> Int -> Ptr Word8 -> IO ()
70+
#if !defined(mingw32_HOST_OS)
71+
recvEOF s tmout0 buf = do
72+
mevmgr <- Ev.getSystemEventManager
73+
case mevmgr of
74+
Nothing -> recvEOFloop s tmout0 buf
75+
Just _ -> recvEOFevent s tmout0 buf
76+
#else
77+
recvEOF = recvEOFloop
78+
#endif
79+
80+
-- Don't use 4092 here. The GHC runtime takes the global lock
81+
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
82+
bufSize :: Int
83+
bufSize = 1024
84+
85+
recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO ()
86+
recvEOFloop s tmout0 buf = loop 1 0
87+
where
88+
loop delay tmout = do
89+
-- We don't check the (positive) length.
90+
-- In normal case, it's 0. That is, only FIN is received.
91+
-- In error cases, data is available. But there is no
92+
-- application which can read it. So, let's stop receiving
93+
-- to prevent attacks.
94+
r <- recvBufNoWait s buf bufSize
95+
when (r == -1 && tmout < tmout0) $ do
96+
threadDelay (delay * 1000)
97+
loop (delay * 2) (tmout + delay)
98+
99+
#if !defined(mingw32_HOST_OS)
100+
data Wait = MoreData | TimeoutTripped
101+
102+
recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO ()
103+
recvEOFevent s tmout0 buf = do
104+
tmmgr <- Ev.getSystemTimerManager
105+
tvar <- newTVarIO False
106+
E.bracket (setup tmmgr tvar) teardown $ \(wait, _) -> do
107+
waitRes <- wait
108+
case waitRes of
109+
TimeoutTripped -> return ()
110+
-- We don't check the (positive) length.
111+
-- In normal case, it's 0. That is, only FIN is received.
112+
-- In error cases, data is available. But there is no
113+
-- application which can read it. So, let's stop receiving
114+
-- to prevent attacks.
115+
MoreData -> void $ recvBufNoWait s buf bufSize
116+
where
117+
setup tmmgr tvar = do
118+
-- millisecond to microsecond
119+
key <- Ev.registerTimeout tmmgr (tmout0 * 1000) $
120+
atomically $ writeTVar tvar True
121+
(evWait, evCancel) <- waitAndCancelReadSocketSTM s
122+
let toWait = do
123+
tmout <- readTVar tvar
124+
check tmout
125+
toCancel = Ev.unregisterTimeout tmmgr key
126+
wait = atomically ((toWait >> return TimeoutTripped)
127+
<|> (evWait >> return MoreData))
128+
cancel = evCancel >> toCancel
129+
return (wait, cancel)
130+
teardown (_, cancel) = cancel
131+
#endif

0 commit comments

Comments
 (0)