@@ -8,14 +8,19 @@ module Network.Socket.Shutdown (
88 , gracefulClose
99 ) where
1010
11+ import Control.Concurrent (threadDelay , yield )
1112import qualified Control.Exception as E
1213import Foreign.Marshal.Alloc (mallocBytes , free )
1314
14- import Control.Concurrent (threadDelay , yield )
15+ #if !defined(mingw32_HOST_OS)
16+ import Control.Concurrent.STM
17+ import qualified GHC.Event as Ev
18+ #endif
1519
1620import Network.Socket.Buffer
1721import Network.Socket.Imports
1822import Network.Socket.Internal
23+ import Network.Socket.STM
1924import Network.Socket.Types
2025
2126data ShutdownCmd = ShutdownReceive
@@ -59,19 +64,68 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s
5964 -- FIN arrives meanwhile.
6065 yield
6166 -- Waiting TCP FIN.
62- E. bracket (mallocBytes bufSize) free recvEOFloop
63- recvEOFloop buf = loop 1 0
64- where
65- loop delay tmout = do
66- -- We don't check the (positive) length.
67- -- In normal case, it's 0. That is, only FIN is received.
68- -- In error cases, data is available. But there is no
69- -- application which can read it. So, let's stop receiving
70- -- to prevent attacks.
71- r <- recvBufNoWait s buf bufSize
72- when (r == - 1 && tmout < tmout0) $ do
73- threadDelay (delay * 1000 )
74- loop (delay * 2 ) (tmout + delay)
75- -- Don't use 4092 here. The GHC runtime takes the global lock
76- -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
77- bufSize = 1024
67+ E. bracket (mallocBytes bufSize) free (recvEOF s tmout0)
68+
69+ recvEOF :: Socket -> Int -> Ptr Word8 -> IO ()
70+ #if !defined(mingw32_HOST_OS)
71+ recvEOF s tmout0 buf = do
72+ mevmgr <- Ev. getSystemEventManager
73+ case mevmgr of
74+ Nothing -> recvEOFloop s tmout0 buf
75+ Just _ -> recvEOFevent s tmout0 buf
76+ #else
77+ recvEOF = recvEOFloop
78+ #endif
79+
80+ -- Don't use 4092 here. The GHC runtime takes the global lock
81+ -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
82+ bufSize :: Int
83+ bufSize = 1024
84+
85+ recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO ()
86+ recvEOFloop s tmout0 buf = loop 1 0
87+ where
88+ loop delay tmout = do
89+ -- We don't check the (positive) length.
90+ -- In normal case, it's 0. That is, only FIN is received.
91+ -- In error cases, data is available. But there is no
92+ -- application which can read it. So, let's stop receiving
93+ -- to prevent attacks.
94+ r <- recvBufNoWait s buf bufSize
95+ when (r == - 1 && tmout < tmout0) $ do
96+ threadDelay (delay * 1000 )
97+ loop (delay * 2 ) (tmout + delay)
98+
99+ #if !defined(mingw32_HOST_OS)
100+ data Wait = MoreData | TimeoutTripped
101+
102+ recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO ()
103+ recvEOFevent s tmout0 buf = do
104+ tmmgr <- Ev. getSystemTimerManager
105+ tvar <- newTVarIO False
106+ E. bracket (setup tmmgr tvar) teardown $ \ (wait, _) -> do
107+ waitRes <- wait
108+ case waitRes of
109+ TimeoutTripped -> return ()
110+ -- We don't check the (positive) length.
111+ -- In normal case, it's 0. That is, only FIN is received.
112+ -- In error cases, data is available. But there is no
113+ -- application which can read it. So, let's stop receiving
114+ -- to prevent attacks.
115+ MoreData -> void $ recvBufNoWait s buf bufSize
116+ where
117+ setup tmmgr tvar = do
118+ -- millisecond to microsecond
119+ key <- Ev. registerTimeout tmmgr (tmout0 * 1000 ) $
120+ atomically $ writeTVar tvar True
121+ (evWait, evCancel) <- waitAndCancelReadSocketSTM s
122+ let toWait = do
123+ tmout <- readTVar tvar
124+ check tmout
125+ toCancel = Ev. unregisterTimeout tmmgr key
126+ wait = atomically ((toWait >> return TimeoutTripped )
127+ <|> (evWait >> return MoreData ))
128+ cancel = evCancel >> toCancel
129+ return (wait, cancel)
130+ teardown (_, cancel) = cancel
131+ #endif
0 commit comments