Skip to content

Commit 15cd90b

Browse files
committed
Merge PR #417
2 parents bbb66aa + 2f62dd3 commit 15cd90b

File tree

6 files changed

+152
-4
lines changed

6 files changed

+152
-4
lines changed

Network/Socket.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ module Network.Socket
122122
-- ** Closing
123123
, close
124124
, close'
125+
, gracefulClose
125126
, shutdown
126127
, ShutdownCmd(..)
127128

Network/Socket/Buffer.hs renamed to Network/Socket/Buffer.hsc

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
{-# LANGUAGE CPP #-}
22

3-
#include "HsNetDef.h"
3+
##include "HsNetDef.h"
4+
#if defined(mingw32_HOST_OS)
5+
# include "windows.h"
6+
#endif
47

58
module Network.Socket.Buffer (
69
sendBufTo
710
, sendBuf
811
, recvBufFrom
912
, recvBuf
13+
, recvBufNoWait
1014
) where
1115

16+
#if !defined(mingw32_HOST_OS)
17+
import Foreign.C.Error (getErrno, eAGAIN, eWOULDBLOCK)
18+
#endif
1219
import Foreign.Marshal.Alloc (alloca)
1320
import GHC.IO.Exception (IOErrorType(InvalidArgument))
1421
import System.IO.Error (mkIOError, ioeSetErrorString, catchIOError)
@@ -129,6 +136,43 @@ recvBuf s ptr nbytes
129136
#endif
130137
return $ fromIntegral len
131138

139+
-- | Receive data from the socket. This function returns immediately
140+
-- even if data is not available. In other words, IO manager is NOT
141+
-- involved. The length of data is returned if received.
142+
-- -1 is returned in the case of EAGAIN or EWOULDBLOCK.
143+
-- -2 is returned in other error cases.
144+
recvBufNoWait :: Socket -> Ptr Word8 -> Int -> IO Int
145+
recvBufNoWait s ptr nbytes = withFdSocket s $ \fd -> do
146+
#if defined(mingw32_HOST_OS)
147+
alloca $ \ptr_bytes -> do
148+
res <- c_ioctlsocket fd #{const FIONREAD} ptr_bytes
149+
avail <- peek ptr_bytes
150+
r <- if res == #{const NO_ERROR} && avail > 0 then
151+
c_recv fd (castPtr ptr) (fromIntegral nbytes) 0{-flags-}
152+
else if avail == 0 then
153+
-- Socket would block, could also mean socket is closed but
154+
-- can't distinguish
155+
return (-1)
156+
else do err <- c_WSAGetLastError
157+
if err == #{const WSAEWOULDBLOCK}
158+
|| err == #{const WSAEINPROGRESS} then
159+
return (-1)
160+
else
161+
return (-2)
162+
return $ fromIntegral r
163+
164+
#else
165+
r <- c_recv fd (castPtr ptr) (fromIntegral nbytes) 0{-flags-}
166+
if r >= 0 then
167+
return $ fromIntegral r
168+
else do
169+
err <- getErrno
170+
if err == eAGAIN || err == eWOULDBLOCK then
171+
return (-1)
172+
else
173+
return (-2)
174+
#endif
175+
132176
mkInvalidRecvArgError :: String -> IOError
133177
mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError
134178
InvalidArgument
@@ -137,9 +181,14 @@ mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError
137181
#if !defined(mingw32_HOST_OS)
138182
foreign import ccall unsafe "send"
139183
c_send :: CInt -> Ptr a -> CSize -> CInt -> IO CInt
184+
#else
185+
foreign import CALLCONV SAFE_ON_WIN "ioctlsocket"
186+
c_ioctlsocket :: CInt -> CLong -> Ptr CULong -> IO CInt
187+
foreign import CALLCONV SAFE_ON_WIN "WSAGetLastError"
188+
c_WSAGetLastError :: IO CInt
189+
#endif
140190
foreign import ccall unsafe "recv"
141191
c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
142-
#endif
143192
foreign import CALLCONV SAFE_ON_WIN "sendto"
144193
c_sendto :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> CInt -> IO CInt
145194
foreign import CALLCONV SAFE_ON_WIN "recvfrom"

Network/Socket/Shutdown.hs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,21 @@
66
module Network.Socket.Shutdown (
77
ShutdownCmd(..)
88
, shutdown
9+
, gracefulClose
910
) where
1011

12+
import qualified Control.Exception as E
13+
import Foreign.Marshal.Alloc (mallocBytes, free)
14+
15+
#if defined(mingw32_HOST_OS)
16+
import Control.Concurrent (threadDelay)
17+
#else
18+
import Control.Concurrent (putMVar, takeMVar, newEmptyMVar)
19+
import qualified GHC.Event as Ev
20+
import System.Posix.Types (Fd(..))
21+
#endif
22+
23+
import Network.Socket.Buffer
1124
import Network.Socket.Imports
1225
import Network.Socket.Internal
1326
import Network.Socket.Types
@@ -34,3 +47,75 @@ shutdown s stype = void $ withFdSocket s $ \fd ->
3447

3548
foreign import CALLCONV unsafe "shutdown"
3649
c_shutdown :: CInt -> CInt -> IO CInt
50+
51+
#if !defined(mingw32_HOST_OS)
52+
data Wait = MoreData | TimeoutTripped
53+
#endif
54+
55+
-- | Closing a socket gracefully.
56+
-- This sends TCP FIN and check if TCP FIN is received from the peer.
57+
-- The second argument is time out to receive TCP FIN in millisecond.
58+
-- In both normal cases and error cases, socket is deallocated finally.
59+
--
60+
-- Since: 3.1.1.0
61+
gracefulClose :: Socket -> Int -> IO ()
62+
gracefulClose s tmout = sendRecvFIN `E.finally` close s
63+
where
64+
sendRecvFIN = do
65+
-- Sending TCP FIN.
66+
shutdown s ShutdownSend
67+
-- Waiting TCP FIN.
68+
recvEOF
69+
#if defined(mingw32_HOST_OS)
70+
-- milliseconds. Taken from BSD fast clock value.
71+
clock = 200
72+
recvEOF = E.bracket (mallocBytes bufSize) free $ loop 0
73+
where
74+
loop delay buf = do
75+
-- We don't check the (positive) length.
76+
-- In normal case, it's 0. That is, only FIN is received.
77+
-- In error cases, data is available. But there is no
78+
-- application which can read it. So, let's stop receiving
79+
-- to prevent attacks.
80+
r <- recvBufNoWait s buf bufSize
81+
let delay' = delay + clock
82+
when (r == -1 && delay' < tmout) $ do
83+
threadDelay (clock * 1000)
84+
loop delay' buf
85+
#else
86+
recvEOF = do
87+
Just evmgr <- Ev.getSystemEventManager
88+
tmmgr <- Ev.getSystemTimerManager
89+
mvar <- newEmptyMVar
90+
E.bracket (register evmgr tmmgr mvar) (unregister evmgr tmmgr) $ \_ -> do
91+
wait <- takeMVar mvar
92+
case wait of
93+
TimeoutTripped -> return ()
94+
-- We don't check the (positive) length.
95+
-- In normal case, it's 0. That is, only FIN is received.
96+
-- In error cases, data is available. But there is no
97+
-- application which can read it. So, let's stop receiving
98+
-- to prevent attacks.
99+
MoreData -> E.bracket (mallocBytes bufSize)
100+
free
101+
(\buf -> void $ recvBufNoWait s buf bufSize)
102+
register evmgr tmmgr mvar = do
103+
-- millisecond to microsecond
104+
key1 <- Ev.registerTimeout tmmgr (tmout * 1000) $
105+
putMVar mvar TimeoutTripped
106+
key2 <- withFdSocket s $ \fd' -> do
107+
let callback _ _ = putMVar mvar MoreData
108+
fd = Fd fd'
109+
#if __GLASGOW_HASKELL__ < 709
110+
Ev.registerFd evmgr callback fd Ev.evtRead
111+
#else
112+
Ev.registerFd evmgr callback fd Ev.evtRead Ev.OneShot
113+
#endif
114+
return (key1, key2)
115+
unregister evmgr tmmgr (key1,key2) = do
116+
Ev.unregisterTimeout tmmgr key1
117+
Ev.unregisterFd evmgr key2
118+
#endif
119+
-- Don't use 4092 here. The GHC runtime takes the global lock
120+
-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit.
121+
bufSize = 1024

configure.ac

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
AC_INIT([Haskell network package],
2-
[3.1.0.1],
2+
[3.1.1.0],
33
44
[network])
55

network.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cabal-version: 1.18
22
name: network
3-
version: 3.1.0.1
3+
version: 3.1.1.0
44
license: BSD3
55
license-file: LICENSE
66
maintainer: Kazu Yamamoto, Evan Borden

tests/Network/SocketSpec.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,16 @@ spec = do
196196
cred1 <- getPeerCredential s
197197
cred1 `shouldBe` (Nothing,Nothing,Nothing)
198198
-}
199+
200+
describe "gracefulClose" $ do
201+
it "does not send TCP RST back" $ do
202+
let server sock = do
203+
void $ recv sock 1024 -- receiving "GOAWAY"
204+
gracefulClose sock 3000
205+
client sock = do
206+
sendAll sock "GOAWAY"
207+
threadDelay 10000
208+
sendAll sock "PING"
209+
threadDelay 10000
210+
void $ recv sock 1024
211+
tcpTest client server

0 commit comments

Comments
 (0)