Skip to content

Commit bda5016

Browse files
Merge pull request #589 from kazu-yamamoto/gracefulClose-based-on-racing-again
Revisiting gracefulClose with STM racing
2 parents 287f2a9 + 10ab2cb commit bda5016

File tree

1 file changed

+61
-17
lines changed

1 file changed

+61
-17
lines changed

Network/Socket/Shutdown.hs

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,20 @@ module Network.Socket.Shutdown (
88
, gracefulClose
99
) where
1010

11+
import Control.Concurrent (yield)
1112
import qualified Control.Exception as E
1213
import Foreign.Marshal.Alloc (mallocBytes, free)
14+
import System.Timeout
1315

14-
import Control.Concurrent (threadDelay, yield)
16+
#if !defined(mingw32_HOST_OS)
17+
import Control.Concurrent.STM
18+
import qualified GHC.Event as Ev
19+
#endif
1520

1621
import Network.Socket.Buffer
1722
import Network.Socket.Imports
1823
import Network.Socket.Internal
24+
import Network.Socket.STM
1925
import Network.Socket.Types
2026

2127
data ShutdownCmd = ShutdownReceive
@@ -59,19 +65,57 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s
5965
-- FIN arrives meanwhile.
6066
yield
6167
-- Waiting TCP FIN.
62-
E.bracket (mallocBytes bufSize) free recvEOFloop
63-
recvEOFloop buf = loop 1 0
64-
where
65-
loop delay tmout = do
66-
-- We don't check the (positive) length.
67-
-- In normal case, it's 0. That is, only FIN is received.
68-
-- In error cases, data is available. But there is no
69-
-- application which can read it. So, let's stop receiving
70-
-- to prevent attacks.
71-
r <- recvBufNoWait s buf bufSize
72-
when (r == -1 && tmout < tmout0) $ do
73-
threadDelay (delay * 1000)
74-
loop (delay * 2) (tmout + delay)
75-
-- Don't use 4092 here. The GHC runtime takes the global lock
76-
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
77-
bufSize = 1024
68+
E.bracket (mallocBytes bufSize) free (recvEOF s tmout0)
69+
70+
recvEOF :: Socket -> Int -> Ptr Word8 -> IO ()
71+
#if !defined(mingw32_HOST_OS)
72+
recvEOF s tmout0 buf = do
73+
mevmgr <- Ev.getSystemEventManager
74+
case mevmgr of
75+
Nothing -> recvEOFloop s tmout0 buf
76+
Just _ -> recvEOFevent s tmout0 buf
77+
#else
78+
recvEOF = recvEOFloop
79+
#endif
80+
81+
-- Don't use 4092 here. The GHC runtime takes the global lock
82+
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
83+
bufSize :: Int
84+
bufSize = 1024
85+
86+
recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO ()
87+
recvEOFloop s tmout0 buf = void $ timeout tmout0 $ recvBuf s buf bufSize
88+
89+
#if !defined(mingw32_HOST_OS)
90+
data Wait = MoreData | TimeoutTripped
91+
92+
recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO ()
93+
recvEOFevent s tmout0 buf = do
94+
tmmgr <- Ev.getSystemTimerManager
95+
tvar <- newTVarIO False
96+
E.bracket (setup tmmgr tvar) teardown $ \(wait, _) -> do
97+
waitRes <- wait
98+
case waitRes of
99+
TimeoutTripped -> return ()
100+
-- We don't check the (positive) length.
101+
-- In normal case, it's 0. That is, only FIN is received.
102+
-- In error cases, data is available. But there is no
103+
-- application which can read it. So, let's stop receiving
104+
-- to prevent attacks.
105+
MoreData -> void $ recvBufNoWait s buf bufSize
106+
where
107+
setup tmmgr tvar = do
108+
-- millisecond to microsecond
109+
key <- Ev.registerTimeout tmmgr (tmout0 * 1000) $
110+
atomically $ writeTVar tvar True
111+
(evWait, evCancel) <- waitAndCancelReadSocketSTM s
112+
let toWait = do
113+
tmout <- readTVar tvar
114+
check tmout
115+
toCancel = Ev.unregisterTimeout tmmgr key
116+
wait = atomically ((toWait >> return TimeoutTripped)
117+
<|> (evWait >> return MoreData))
118+
cancel = evCancel >> toCancel
119+
return (wait, cancel)
120+
teardown (_, cancel) = cancel
121+
#endif

0 commit comments

Comments
 (0)