@@ -11,6 +11,7 @@ module Network.Socket.Posix.Cmsg where
11
11
#include <sys/socket.h>
12
12
13
13
import Data.ByteString.Internal
14
+ import Data.Proxy
14
15
import Foreign.ForeignPtr
15
16
import System.IO.Unsafe (unsafeDupablePerformIO )
16
17
import System.Posix.Types (Fd (.. ))
@@ -87,24 +88,27 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
87
88
-- Each control message type has a numeric 'CmsgId' and a 'Storable'
88
89
-- data representation.
89
90
class Storable a => ControlMessage a where
90
- controlMessageId :: a -> CmsgId
91
+ controlMessageId :: Proxy a -> CmsgId
91
92
92
- encodeCmsg :: ControlMessage a => a -> Cmsg
93
+ encodeCmsg :: forall a . ControlMessage a => a -> Cmsg
93
94
encodeCmsg x = unsafeDupablePerformIO $ do
94
95
bs <- create siz $ \ p0 -> do
95
96
let p = castPtr p0
96
97
poke p x
97
- return $ Cmsg (controlMessageId x) bs
98
+ let cmsid = controlMessageId (Proxy :: Proxy a )
99
+ return $ Cmsg cmsid bs
98
100
where
99
101
siz = sizeOf x
100
102
101
- decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
102
- decodeCmsg (Cmsg _ (PS fptr off len))
103
- | len < siz = Nothing
104
- | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \ p0 -> do
103
+ decodeCmsg :: forall a . (ControlMessage a , Storable a ) => Cmsg -> Maybe a
104
+ decodeCmsg (Cmsg cmsid (PS fptr off len))
105
+ | cid /= cmsid = Nothing
106
+ | len < siz = Nothing
107
+ | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \ p0 -> do
105
108
let p = castPtr (p0 `plusPtr` off)
106
109
Just <$> peek p
107
110
where
111
+ cid = controlMessageId (Proxy :: Proxy a )
108
112
siz = sizeOf (undefined :: a )
109
113
110
114
----------------------------------------------------------------
0 commit comments