Skip to content

Commit 507a2c0

Browse files
Mistukekazu-yamamoto
authored andcommitted
Finish windows implementation
1 parent fb8529f commit 507a2c0

File tree

10 files changed

+110
-29
lines changed

10 files changed

+110
-29
lines changed

Network/Socket.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ module Network.Socket
140140
,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo
141141
,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo)
142142
, isSupportedSocketOption
143+
, whenSupported
143144
, getSocketOption
144145
, setSocketOption
145146
, getSockOpt

Network/Socket/Buffer.hsc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,17 @@ recvBufMsg s bufsizs clen flags = do
267267
, msgBuffer = wsaBPtr
268268
, msgBufferLen = fromIntegral wsaBLen
269269
#endif
270+
#if !defined(mingw32_HOST_OS)
270271
, msgCtrl = castPtr ctrlPtr
272+
#else
273+
, msgCtrl = if clen == 0 then nullPtr else castPtr ctrlPtr
274+
#endif
271275
, msgCtrlLen = fromIntegral clen
276+
#if !defined(mingw32_HOST_OS)
272277
, msgFlags = 0
278+
#else
279+
, msgFlags = fromIntegral $ fromMsgFlag flags
280+
#endif
273281
}
274282
_cflags = fromMsgFlag flags
275283
withFdSocket s $ \fd -> do
@@ -280,7 +288,7 @@ recvBufMsg s bufsizs clen flags = do
280288
c_recvmsg fd msgHdrPtr _cflags
281289
#else
282290
alloca $ \len_ptr -> do
283-
_ <- throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $
291+
_ <- throwSocketErrorWaitReadBut (== #{const WSAEMSGSIZE}) s "Network.Socket.Buffer.recvmg" $
284292
c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr
285293
peek len_ptr
286294
#endif

Network/Socket/Internal.hs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ module Network.Socket.Internal
3434
, throwSocketErrorIfMinus1Retry
3535
, throwSocketErrorIfMinus1Retry_
3636
, throwSocketErrorIfMinus1RetryMayBlock
37+
, throwSocketErrorIfMinus1ButRetry
3738

3839
-- ** Guards that wait and retry if the operation would block
3940
-- | These guards are based on 'throwSocketErrorIfMinus1RetryMayBlock'.
4041
-- They wait for socket readiness if the action fails with @EWOULDBLOCK@
4142
-- or similar.
4243
, throwSocketErrorWaitRead
44+
, throwSocketErrorWaitReadBut
4345
, throwSocketErrorWaitWrite
4446

4547
-- * Initialization
@@ -134,16 +136,37 @@ throwSocketErrorIfMinus1RetryMayBlock
134136
{-# SPECIALIZE throwSocketErrorIfMinus1RetryMayBlock
135137
:: String -> IO b -> IO CInt -> IO CInt #-}
136138

139+
140+
-- | Throw an 'IOError' corresponding to the current socket error if
141+
-- the IO action returns a result of @-1@, but retries in case of an
142+
-- interrupted operation. Checks for operations that would block and
143+
-- executes an alternative action before retrying in that case. If the error
144+
-- is one handled by the exempt filter then ignore it and return the errorcode.
145+
throwSocketErrorIfMinus1RetryMayBlockBut
146+
:: (Eq a, Num a)
147+
=> (CInt -> Bool) -- ^ exception exempt filter
148+
-> String -- ^ textual description of the location
149+
-> IO b -- ^ action to execute before retrying if an
150+
-- immediate retry would block
151+
-> IO a -- ^ the 'IO' operation to be executed
152+
-> IO a
153+
154+
{-# SPECIALIZE throwSocketErrorIfMinus1RetryMayBlock
155+
:: String -> IO b -> IO CInt -> IO CInt #-}
156+
137157
#if defined(mingw32_HOST_OS)
138158

139159
throwSocketErrorIfMinus1RetryMayBlock name _ act
140160
= throwSocketErrorIfMinus1Retry name act
141161

162+
throwSocketErrorIfMinus1RetryMayBlockBut exempt name _ act
163+
= throwSocketErrorIfMinus1ButRetry exempt name act
164+
142165
throwSocketErrorIfMinus1_ name act = do
143166
_ <- throwSocketErrorIfMinus1Retry name act
144167
return ()
145168

146-
throwSocketErrorIfMinus1Retry name act = do
169+
throwSocketErrorIfMinus1ButRetry exempt name act = do
147170
r <- act
148171
if (r == -1)
149172
then do
@@ -155,7 +178,9 @@ throwSocketErrorIfMinus1Retry name act = do
155178
then throwSocketError name
156179
else return r'
157180
else
158-
throwSocketError name
181+
if (exempt rc)
182+
then return r
183+
else throwSocketError name
159184
else return r
160185

161186
throwSocketErrorCode name rc = do
@@ -177,6 +202,9 @@ foreign import ccall unsafe "getWSErrorDescr"
177202
throwSocketErrorIfMinus1RetryMayBlock name on_block act =
178203
throwErrnoIfMinus1RetryMayBlock name act on_block
179204

205+
throwSocketErrorIfMinus1RetryMayBlockBut _exempt name on_block act =
206+
throwErrnoIfMinus1RetryMayBlock name act on_block
207+
180208
throwSocketErrorIfMinus1Retry = throwErrnoIfMinus1Retry
181209

182210
throwSocketErrorIfMinus1_ = throwErrnoIfMinus1_
@@ -188,6 +216,9 @@ throwSocketErrorCode loc errno =
188216

189217
#endif
190218

219+
throwSocketErrorIfMinus1Retry
220+
= throwSocketErrorIfMinus1ButRetry (const False)
221+
191222
-- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with
192223
-- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready,
193224
-- and try again.
@@ -196,6 +227,15 @@ throwSocketErrorWaitRead s name io = withFdSocket s $ \fd ->
196227
throwSocketErrorIfMinus1RetryMayBlock name
197228
(threadWaitRead $ fromIntegral fd) io
198229

230+
-- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with
231+
-- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready,
232+
-- and try again. If it fails with the error the user was expecting then
233+
-- ignore the error
234+
throwSocketErrorWaitReadBut :: (Eq a, Num a) => (CInt -> Bool) -> Socket -> String -> IO a -> IO a
235+
throwSocketErrorWaitReadBut exempt s name io = withFdSocket s $ \fd ->
236+
throwSocketErrorIfMinus1RetryMayBlockBut exempt name
237+
(threadWaitRead $ fromIntegral fd) io
238+
199239
-- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with
200240
-- @EWOULDBLOCK@ or similar, wait for the socket to be write-ready,
201241
-- and try again.

Network/Socket/Options.hsc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module Network.Socket.Options (
1616
,RecvIPv4TTL,RecvIPv4TOS,RecvIPv4PktInfo
1717
,RecvIPv6HopLimit,RecvIPv6TClass,RecvIPv6PktInfo)
1818
, isSupportedSocketOption
19+
, whenSupported
1920
, getSocketType
2021
, getSocketOption
2122
, setSocketOption
@@ -289,6 +290,13 @@ instance Storable StructLinger where
289290
(#poke struct linger, l_linger) p linger
290291
#endif
291292

293+
-- | Executes the given action and ignoring the result only when the specified
294+
-- socket option is valid.
295+
whenSupported :: SocketOption -> IO a -> IO ()
296+
whenSupported s action
297+
| isSupportedSocketOption s = action >> return ()
298+
| otherwise = return ()
299+
292300
-- | Set a socket option that expects an Int value.
293301
-- There is currently no API to set e.g. the timeval socket options
294302
setSocketOption :: Socket

Network/Socket/Win32/Cmsg.hsc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ pattern CmsgIdIPv6HopLimit = CmsgId (#const IPPROTO_IPV6) (#const IPV6_HOPLIMIT)
4141

4242
-- | The identifier for 'IPv4TOS'.
4343
pattern CmsgIdIPv4TOS :: CmsgId
44-
pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_RECVTOS)
44+
pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_TOS)
4545

4646
-- | The identifier for 'IPv6TClass'.
4747
pattern CmsgIdIPv6TClass :: CmsgId
48-
pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_RECVTCLASS)
48+
pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_TCLASS)
4949

5050
-- | The identifier for 'IPv4PktInfo'.
5151
pattern CmsgIdIPv4PktInfo :: CmsgId

Network/Socket/Win32/CmsgHdr.hsc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,21 @@ parseCmsgs msgptr = do
7272
loop ptr build
7373
| ptr == nullPtr = return $ build []
7474
| otherwise = do
75-
cmsg <- fromCmsgHdr ptr
76-
nextPtr <- c_cmsg_nxthdr msgptr ptr
77-
loop nextPtr (build . (cmsg :))
78-
79-
fromCmsgHdr :: Ptr CmsgHdr -> IO Cmsg
75+
val <- fromCmsgHdr ptr
76+
case val of
77+
Nothing -> return $ build []
78+
Just cmsg -> do
79+
nextPtr <- c_cmsg_nxthdr msgptr ptr
80+
loop nextPtr (build . (cmsg :))
81+
82+
fromCmsgHdr :: Ptr CmsgHdr -> IO (Maybe Cmsg)
8083
fromCmsgHdr ptr = do
8184
CmsgHdr len lvl typ <- peek ptr
8285
src <- c_cmsg_data ptr
8386
let siz = fromIntegral len - (src `minusPtr` ptr)
84-
Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz)
87+
if siz < 0
88+
then return Nothing
89+
else Just . Cmsg (CmsgId lvl typ) <$> create (fromIntegral siz) (\dst -> memcpy dst src siz)
8590

8691
foreign import ccall unsafe "cmsg_firsthdr"
8792
c_cmsg_firsthdr :: Ptr (MsgHdr sa) -> IO (Ptr CmsgHdr)

Network/Socket/Win32/MsgHdr.hsc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ data MsgHdr sa = MsgHdr
2525
, msgCtrl :: !(Ptr Word8)
2626
, msgCtrlLen :: !ULONG
2727
, msgFlags :: !DWORD
28-
}
28+
} deriving Show
2929

3030
instance Storable (MsgHdr sa) where
3131
sizeOf = const #{size WSAMSG}

cbits/cmsg.c

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,17 @@ WSARecvMsg (SOCKET s, LPWSAMSG lpMsg, LPDWORD lpdwNumberOfBytesRecvd,
6262
return -1;
6363
}
6464

65-
return ptr_RecvMsg (s, lpMsg, lpdwNumberOfBytesRecvd, lpOverlapped,
66-
lpCompletionRoutine);
65+
int res = ptr_RecvMsg (s, lpMsg, lpdwNumberOfBytesRecvd, lpOverlapped,
66+
lpCompletionRoutine);
67+
68+
/* If the msg was truncated then this pointer can be garbage. */
69+
if (res == SOCKET_ERROR && GetLastError () == WSAEMSGSIZE)
70+
{
71+
lpMsg->Control.len = 0;
72+
lpMsg->Control.buf = NULL;
73+
}
74+
75+
return res;
6776
}
6877
#else
6978
struct cmsghdr *cmsg_firsthdr(struct msghdr *mhdr) {

include/win32defs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,7 @@
114114
#endif
115115
#ifndef IP_RECVERR
116116
#define IP_RECVERR 75 // Receive ICMP errors.
117+
#endif
118+
#ifndef IPV6_TCLASS
119+
#define IPV6_TCLASS 39
117120
#endif

tests/Network/Socket/ByteStringSpec.hs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -229,39 +229,46 @@ spec = do
229229

230230
it "receives control messages for IPv4" $ do
231231
let server sock = do
232-
setSocketOption sock RecvIPv4TTL 1
233-
setSocketOption sock RecvIPv4TOS 1
234-
setSocketOption sock RecvIPv4PktInfo 1
232+
whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1
233+
whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1
234+
whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1
235235
(_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty
236236

237-
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
238-
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing
239-
((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing
237+
whenSupported RecvIPv4TTL $
238+
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
239+
whenSupported RecvIPv4TOS $
240+
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing
241+
whenSupported RecvIPv4PktInfo $
242+
((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing
240243
client sock addr = sendTo sock seg addr
241244

242245
seg = C.pack "This is a test message"
243246
udpTest client server
244247

245248
it "receives control messages for IPv6" $ do
246249
let server sock = do
247-
setSocketOption sock RecvIPv6HopLimit 1
248-
setSocketOption sock RecvIPv6TClass 1
249-
setSocketOption sock RecvIPv6PktInfo 1
250+
whenSupported RecvIPv6HopLimit $ setSocketOption sock RecvIPv6HopLimit 1
251+
whenSupported RecvIPv6TClass $ setSocketOption sock RecvIPv6TClass 1
252+
whenSupported RecvIPv6PktInfo $ setSocketOption sock RecvIPv6PktInfo 1
250253
(_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty
251254

252-
((lookupCmsg CmsgIdIPv6HopLimit cmsgs >>= decodeCmsg) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing
253-
((lookupCmsg CmsgIdIPv6TClass cmsgs >>= decodeCmsg) :: Maybe IPv6TClass) `shouldNotBe` Nothing
254-
((lookupCmsg CmsgIdIPv6PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing
255+
256+
whenSupported RecvIPv6HopLimit $
257+
((lookupCmsg CmsgIdIPv6HopLimit cmsgs >>= decodeCmsg) :: Maybe IPv6HopLimit) `shouldNotBe` Nothing
258+
whenSupported RecvIPv6TClass $
259+
((lookupCmsg CmsgIdIPv6TClass cmsgs >>= decodeCmsg) :: Maybe IPv6TClass) `shouldNotBe` Nothing
260+
whenSupported RecvIPv6PktInfo $
261+
((lookupCmsg CmsgIdIPv6PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv6PktInfo) `shouldNotBe` Nothing
255262
client sock addr = sendTo sock seg addr
256263

257264
seg = C.pack "This is a test message"
258265
udpTest6 client server
259266

260267
it "receives truncated control messages" $ do
261268
let server sock = do
262-
setSocketOption sock RecvIPv4TTL 1
263-
setSocketOption sock RecvIPv4TOS 1
264-
setSocketOption sock RecvIPv4PktInfo 1
269+
whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1
270+
whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1
271+
whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1
265272
(_, _, _, flags) <- recvMsg sock 1024 10 mempty
266273
flags .&. MSG_CTRUNC `shouldBe` MSG_CTRUNC
267274

0 commit comments

Comments
 (0)