@@ -8,14 +8,20 @@ module Network.Socket.Shutdown (
8
8
, gracefulClose
9
9
) where
10
10
11
+ import Control.Concurrent (yield )
11
12
import qualified Control.Exception as E
12
13
import Foreign.Marshal.Alloc (mallocBytes , free )
14
+ import System.Timeout
13
15
14
- import Control.Concurrent (threadDelay , yield )
16
+ #if !defined(mingw32_HOST_OS)
17
+ import Control.Concurrent.STM
18
+ import qualified GHC.Event as Ev
19
+ #endif
15
20
16
21
import Network.Socket.Buffer
17
22
import Network.Socket.Imports
18
23
import Network.Socket.Internal
24
+ import Network.Socket.STM
19
25
import Network.Socket.Types
20
26
21
27
data ShutdownCmd = ShutdownReceive
@@ -59,19 +65,57 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s
59
65
-- FIN arrives meanwhile.
60
66
yield
61
67
-- Waiting TCP FIN.
62
- E. bracket (mallocBytes bufSize) free recvEOFloop
63
- recvEOFloop buf = loop 1 0
64
- where
65
- loop delay tmout = do
66
- -- We don't check the (positive) length.
67
- -- In normal case, it's 0. That is, only FIN is received.
68
- -- In error cases, data is available. But there is no
69
- -- application which can read it. So, let's stop receiving
70
- -- to prevent attacks.
71
- r <- recvBufNoWait s buf bufSize
72
- when (r == - 1 && tmout < tmout0) $ do
73
- threadDelay (delay * 1000 )
74
- loop (delay * 2 ) (tmout + delay)
75
- -- Don't use 4092 here. The GHC runtime takes the global lock
76
- -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
77
- bufSize = 1024
68
+ E. bracket (mallocBytes bufSize) free (recvEOF s tmout0)
69
+
70
+ recvEOF :: Socket -> Int -> Ptr Word8 -> IO ()
71
+ #if !defined(mingw32_HOST_OS)
72
+ recvEOF s tmout0 buf = do
73
+ mevmgr <- Ev. getSystemEventManager
74
+ case mevmgr of
75
+ Nothing -> recvEOFloop s tmout0 buf
76
+ Just _ -> recvEOFevent s tmout0 buf
77
+ #else
78
+ recvEOF = recvEOFloop
79
+ #endif
80
+
81
+ -- Don't use 4092 here. The GHC runtime takes the global lock
82
+ -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
83
+ bufSize :: Int
84
+ bufSize = 1024
85
+
86
+ recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO ()
87
+ recvEOFloop s tmout0 buf = void $ timeout tmout0 $ recvBuf s buf bufSize
88
+
89
+ #if !defined(mingw32_HOST_OS)
90
+ data Wait = MoreData | TimeoutTripped
91
+
92
+ recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO ()
93
+ recvEOFevent s tmout0 buf = do
94
+ tmmgr <- Ev. getSystemTimerManager
95
+ tvar <- newTVarIO False
96
+ E. bracket (setup tmmgr tvar) teardown $ \ (wait, _) -> do
97
+ waitRes <- wait
98
+ case waitRes of
99
+ TimeoutTripped -> return ()
100
+ -- We don't check the (positive) length.
101
+ -- In normal case, it's 0. That is, only FIN is received.
102
+ -- In error cases, data is available. But there is no
103
+ -- application which can read it. So, let's stop receiving
104
+ -- to prevent attacks.
105
+ MoreData -> void $ recvBufNoWait s buf bufSize
106
+ where
107
+ setup tmmgr tvar = do
108
+ -- millisecond to microsecond
109
+ key <- Ev. registerTimeout tmmgr (tmout0 * 1000 ) $
110
+ atomically $ writeTVar tvar True
111
+ (evWait, evCancel) <- waitAndCancelReadSocketSTM s
112
+ let toWait = do
113
+ tmout <- readTVar tvar
114
+ check tmout
115
+ toCancel = Ev. unregisterTimeout tmmgr key
116
+ wait = atomically ((toWait >> return TimeoutTripped )
117
+ <|> (evWait >> return MoreData ))
118
+ cancel = evCancel >> toCancel
119
+ return (wait, cancel)
120
+ teardown (_, cancel) = cancel
121
+ #endif
0 commit comments