Skip to content

Commit e1a4057

Browse files
Mistukekazu-yamamoto
authored andcommitted
network: Initial implementation
1 parent 3da86ef commit e1a4057

File tree

10 files changed

+475
-10
lines changed

10 files changed

+475
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ cabal.sandbox.config
1919
.cabal-sandbox
2020
.stack-work/
2121
.ghc.*
22+
.vscode

Network/Socket/Buffer.hsc

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ module Network.Socket.Buffer (
1717

1818
#if !defined(mingw32_HOST_OS)
1919
import Foreign.C.Error (getErrno, eAGAIN, eWOULDBLOCK)
20+
#else
21+
import System.Win32.Types
22+
import Foreign.Ptr (nullPtr)
2023
#endif
2124
import Foreign.Marshal.Alloc (alloca, allocaBytes)
2225
import Foreign.Marshal.Utils (with)
@@ -25,15 +28,19 @@ import System.IO.Error (mkIOError, ioeSetErrorString, catchIOError)
2528

2629
#if defined(mingw32_HOST_OS)
2730
import GHC.IO.FD (FD(..), readRawBufferPtr, writeRawBufferPtr)
31+
import Network.Socket.Win32.CmsgHdr
32+
import Network.Socket.Win32.MsgHdr
33+
import Network.Socket.Win32.WSABuf
34+
#else
35+
import Network.Socket.Posix.CmsgHdr
36+
import Network.Socket.Posix.MsgHdr
37+
import Network.Socket.Posix.IOVec
2838
#endif
2939

3040
import Network.Socket.Imports
3141
import Network.Socket.Internal
3242
import Network.Socket.Name
3343
import Network.Socket.Types
34-
import Network.Socket.Posix.CmsgHdr
35-
import Network.Socket.Posix.MsgHdr
36-
import Network.Socket.Posix.IOVec
3744
import Network.Socket.Flag
3845

3946
-- | Send data to the socket. The recipient can be specified
@@ -195,13 +202,22 @@ sendBufMsg :: SocketAddress sa
195202
-> IO Int -- ^ The length actually sent
196203
sendBufMsg s sa bufsizs cmsgs flags = do
197204
sz <- withSocketAddress sa $ \addrPtr addrSize ->
205+
#if !defined(mingw32_HOST_OS)
198206
withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do
207+
#else
208+
withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do
209+
#endif
199210
withCmsgs cmsgs $ \ctrlPtr ctrlLen -> do
200211
let msgHdr = MsgHdr {
201212
msgName = addrPtr
202213
, msgNameLen = fromIntegral addrSize
214+
#if !defined(mingw32_HOST_OS)
203215
, msgIov = iovsPtr
204216
, msgIovLen = fromIntegral iovsLen
217+
#else
218+
, msgBuffer = wsaBPtr
219+
, msgBufferLen = fromIntegral wsaBLen
220+
#endif
205221
, msgCtrl = castPtr ctrlPtr
206222
, msgCtrlLen = fromIntegral ctrlLen
207223
, msgFlags = 0
@@ -210,7 +226,12 @@ sendBufMsg s sa bufsizs cmsgs flags = do
210226
withFdSocket s $ \fd ->
211227
with msgHdr $ \msgHdrPtr ->
212228
throwSocketErrorWaitWrite s "Network.Socket.Buffer.sendMsg" $
229+
#if !defined(mingw32_HOST_OS)
213230
c_sendmsg fd msgHdrPtr cflags
231+
#else
232+
alloca $ \send_ptr ->
233+
c_sendmsg fd msgHdrPtr cflags send_ptr nullPtr nullPtr
234+
#endif
214235
return $ fromIntegral sz
215236

216237
-- | Receive data from the socket using recvmsg(2).
@@ -227,20 +248,38 @@ recvBufMsg :: SocketAddress sa
227248
recvBufMsg s bufsizs clen flags = do
228249
withNewSocketAddress $ \addrPtr addrSize ->
229250
allocaBytes clen $ \ctrlPtr ->
251+
#if !defined(mingw32_HOST_OS)
230252
withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do
253+
#else
254+
withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do
255+
#endif
231256
let msgHdr = MsgHdr {
232257
msgName = addrPtr
233258
, msgNameLen = fromIntegral addrSize
259+
#if !defined(mingw32_HOST_OS)
234260
, msgIov = iovsPtr
235261
, msgIovLen = fromIntegral iovsLen
262+
#else
263+
, msgBuffer = wsaBPtr
264+
, msgBufferLen = fromIntegral wsaBLen
265+
#endif
236266
, msgCtrl = castPtr ctrlPtr
237267
, msgCtrlLen = fromIntegral clen
238268
, msgFlags = 0
239269
}
240-
cflags = fromMsgFlag flags
270+
_cflags = fromMsgFlag flags
241271
withFdSocket s $ \fd -> do
242272
with msgHdr $ \msgHdrPtr -> do
243-
len <- fromIntegral <$> throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" (c_recvmsg fd msgHdrPtr cflags)
273+
len <- fromIntegral <$>
274+
#if !defined(mingw32_HOST_OS)
275+
throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $
276+
c_recvmsg fd msgHdrPtr _cflags
277+
#else
278+
alloca $ \len_ptr ->
279+
throwSocketErrorWaitRead s "Network.Socket.Buffer.recvmg" $
280+
c_recvmsg fd msgHdrPtr len_ptr nullPtr nullPtr
281+
peek len_ptr
282+
#endif
244283
sockaddr <- peekSocketAddress addrPtr `catchIOError` \_ -> getPeerName s
245284
hdr <- peek msgHdrPtr
246285
cmsgs <- parseCmsgs msgHdrPtr
@@ -250,20 +289,28 @@ recvBufMsg s bufsizs clen flags = do
250289
#if !defined(mingw32_HOST_OS)
251290
foreign import ccall unsafe "send"
252291
c_send :: CInt -> Ptr a -> CSize -> CInt -> IO CInt
292+
foreign import ccall unsafe "sendmsg"
293+
c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt -- fixme CSsize
294+
foreign import ccall unsafe "recvmsg"
295+
c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt
253296
#else
254297
foreign import CALLCONV SAFE_ON_WIN "ioctlsocket"
255298
c_ioctlsocket :: CInt -> CLong -> Ptr CULong -> IO CInt
256299
foreign import CALLCONV SAFE_ON_WIN "WSAGetLastError"
257300
c_WSAGetLastError :: IO CInt
301+
foreign import CALLCONV SAFE_ON_WIN "sendmsg"
302+
-- fixme Handle for SOCKET, see #426
303+
c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> DWORD -> LPDWORD -> Ptr () -> Ptr () -> IO CInt
304+
foreign import CALLCONV SAFE_ON_WIN "recvmsg"
305+
c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> LPDWORD -> Ptr () -> Ptr () -> IO CInt
306+
307+
failIfSockError = failIf_ (==#{const SOCKET_ERROR})
258308
#endif
309+
259310
foreign import ccall unsafe "recv"
260311
c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
261312
foreign import CALLCONV SAFE_ON_WIN "sendto"
262313
c_sendto :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> CInt -> IO CInt
263314
foreign import CALLCONV SAFE_ON_WIN "recvfrom"
264315
c_recvfrom :: CInt -> Ptr a -> CSize -> CInt -> Ptr sa -> Ptr CInt -> IO CInt
265316

266-
foreign import ccall unsafe "sendmsg"
267-
c_sendmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt -- fixme CSsize
268-
foreign import ccall unsafe "recvmsg"
269-
c_recvmsg :: CInt -> Ptr (MsgHdr sa) -> CInt -> IO CInt

Network/Socket/Win32/Cmsg.hsc

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
{-# LANGUAGE CPP #-}
2+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
3+
{-# LANGUAGE PatternSynonyms #-}
4+
{-# LANGUAGE ScopedTypeVariables #-}
5+
6+
module Network.Socket.Win32.Cmsg where
7+
8+
#include "HsNet.h"
9+
10+
import Data.ByteString.Internal
11+
import Foreign.ForeignPtr
12+
import System.IO.Unsafe (unsafeDupablePerformIO)
13+
import System.Win32.Types (HANDLE)
14+
15+
import Network.Socket.Imports
16+
import Network.Socket.Types
17+
18+
-- | Control message (ancillary data) including a pair of level and type.
19+
data Cmsg = Cmsg {
20+
cmsgId :: !CmsgId
21+
, cmsgData :: !ByteString
22+
} deriving (Eq, Show)
23+
24+
----------------------------------------------------------------
25+
26+
-- | Identifier of control message (ancillary data).
27+
data CmsgId = CmsgId {
28+
cmsgLevel :: !CInt
29+
, cmsglType :: !CInt
30+
} deriving (Eq, Show)
31+
32+
-- | The identifier for 'IPv4TTL'.
33+
pattern CmsgIdIPv4TTL :: CmsgId
34+
pattern CmsgIdIPv4TTL = CmsgId (#const IPPROTO_IP) (#const IP_TTL)
35+
36+
-- | The identifier for 'IPv6HopLimit'.
37+
pattern CmsgIdIPv6HopLimit :: CmsgId
38+
pattern CmsgIdIPv6HopLimit = CmsgId (#const IPPROTO_IPV6) (#const IPV6_HOPLIMIT)
39+
40+
-- | The identifier for 'IPv4TOS'.
41+
pattern CmsgIdIPv4TOS :: CmsgId
42+
pattern CmsgIdIPv4TOS = CmsgId (#const IPPROTO_IP) (#const IP_RECVTOS)
43+
44+
-- | The identifier for 'IPv6TClass'.
45+
pattern CmsgIdIPv6TClass :: CmsgId
46+
pattern CmsgIdIPv6TClass = CmsgId (#const IPPROTO_IPV6) (#const IPV6_RECVTCLASS)
47+
48+
-- | The identifier for 'IPv4PktInfo'.
49+
pattern CmsgIdIPv4PktInfo :: CmsgId
50+
pattern CmsgIdIPv4PktInfo = CmsgId (#const IPPROTO_IP) (#const IP_PKTINFO)
51+
52+
-- | The identifier for 'IPv6PktInfo'.
53+
pattern CmsgIdIPv6PktInfo :: CmsgId
54+
pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO)
55+
56+
-- Use WSADuplicateSocket for CmsgIdFd
57+
-- pattern CmsgIdFd :: CmsgId
58+
59+
----------------------------------------------------------------
60+
61+
-- | Looking up control message. The following shows an example usage:
62+
--
63+
-- > (lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS
64+
lookupCmsg :: CmsgId -> [Cmsg] -> Maybe Cmsg
65+
lookupCmsg _ [] = Nothing
66+
lookupCmsg cid (cmsg:cmsgs)
67+
| cmsgId cmsg == cid = Just cmsg
68+
| otherwise = lookupCmsg cid cmsgs
69+
70+
-- | Filtering control message.
71+
filterCmsg :: CmsgId -> [Cmsg] -> [Cmsg]
72+
filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
73+
74+
----------------------------------------------------------------
75+
76+
-- | A class to encode and decode control message.
77+
class Storable a => ControlMessage a where
78+
controlMessageId :: a -> CmsgId
79+
80+
encodeCmsg :: ControlMessage a => a -> Cmsg
81+
encodeCmsg x = unsafeDupablePerformIO $ do
82+
bs <- create siz $ \p0 -> do
83+
let p = castPtr p0
84+
poke p x
85+
return $ Cmsg (controlMessageId x) bs
86+
where
87+
siz = sizeOf x
88+
89+
decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
90+
decodeCmsg (Cmsg _ (PS fptr off len))
91+
| len < siz = Nothing
92+
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
93+
let p = castPtr (p0 `plusPtr` off)
94+
Just <$> peek p
95+
where
96+
siz = sizeOf (undefined :: a)
97+
98+
----------------------------------------------------------------
99+
100+
-- | Time to live of IPv4.
101+
newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)
102+
103+
instance ControlMessage IPv4TTL where
104+
controlMessageId _ = CmsgIdIPv4TTL
105+
106+
----------------------------------------------------------------
107+
108+
-- | Hop limit of IPv6.
109+
newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)
110+
111+
instance ControlMessage IPv6HopLimit where
112+
controlMessageId _ = CmsgIdIPv6HopLimit
113+
114+
----------------------------------------------------------------
115+
116+
-- | TOS of IPv4.
117+
newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)
118+
119+
instance ControlMessage IPv4TOS where
120+
controlMessageId _ = CmsgIdIPv4TOS
121+
122+
----------------------------------------------------------------
123+
124+
-- | Traffic class of IPv6.
125+
newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)
126+
127+
instance ControlMessage IPv6TClass where
128+
controlMessageId _ = CmsgIdIPv6TClass
129+
130+
----------------------------------------------------------------
131+
132+
-- | Network interface ID and local IPv4 address.
133+
data IPv4PktInfo = IPv4PktInfo ULONG HostAddress deriving (Eq)
134+
135+
instance Show IPv4PktInfo where
136+
show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha)
137+
138+
instance ControlMessage IPv4PktInfo where
139+
controlMessageId _ = CmsgIdIPv4PktInfo
140+
141+
instance Storable IPv4PktInfo where
142+
sizeOf _ = const #{size IN_PKTINFO}
143+
alignment _ = #alignment IN_PKTINFO
144+
poke p (IPv4PktInfo n sa ha) = do
145+
(#poke IN_PKTINFO, ipi_ifindex) p (fromIntegral n :: CInt)
146+
(#poke IN_PKTINFO, ipi_addr) p ha
147+
peek p = do
148+
n <- (#peek IN_PKTINFO, ipi_ifindex) p
149+
ha <- (#peek IN_PKTINFO, ipi_addr) p
150+
return $ IPv4PktInfo n ha
151+
152+
----------------------------------------------------------------
153+
154+
-- | Network interface ID and local IPv4 address.
155+
data IPv6PktInfo = IPv6PktInfo Int HostAddress6 deriving (Eq)
156+
157+
instance Show IPv6PktInfo where
158+
show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6)
159+
160+
instance ControlMessage IPv6PktInfo where
161+
controlMessageId _ = CmsgIdIPv6PktInfo
162+
163+
instance Storable IPv6PktInfo where
164+
sizeOf _ = const #{size IN6_PKTINFO}
165+
alignment _ = #alignment IN6_PKTINFO
166+
poke p (IPv6PktInfo n ha6) = do
167+
(#poke IN6_PKTINFO, ipi6_ifindex) p (fromIntegral n :: CInt)
168+
(#poke IN6_PKTINFO, ipi6_addr) p (In6Addr ha6)
169+
peek p = do
170+
In6Addr ha6 <- (#peek IN6_PKTINFO, ipi6_addr) p
171+
n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p
172+
return $ IPv6PktInfo (fromIntegral n) ha6

0 commit comments

Comments
 (0)