Skip to content

Commit f29aa06

Browse files
committed
Implement socket endpoints, in particular reading a string description of them
Useful for 3rd-party networking applications that might want to pass around service specifiers without worrying whether these are IP addresses, DNS names, or UNIX-domain socket paths. Previously, there was no data type to encapsulate these options together. In particular, getAddrInfo had to be used to resolve DNS names into a SockAddr before calling connect/bind, but it could not deal with UNIX domain sockets. The new function sockNameToAddr takes this role, transparently converting DNS names and passing through non-DNS-names unaltered, so that it can be used uniformly without worrying about the specific type of input name/address.
1 parent 29c11bf commit f29aa06

File tree

4 files changed

+122
-2
lines changed

4 files changed

+122
-2
lines changed

Network/Socket.hs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ module Network.Socket
152152
, Socket
153153
, socket
154154
, openSocket
155+
, socketFromEndpoint
155156
, withFdSocket
156157
, unsafeFdSocket
157158
, touchSocket
@@ -182,8 +183,14 @@ module Network.Socket
182183
-- ** Protocol number
183184
, ProtocolNumber
184185
, defaultProtocol
186+
-- * Basic socket endpoint type
187+
, SockEndpoint(..)
188+
, readSockEndpoint
189+
, showSockEndpoint
190+
, resolveEndpoint
185191
-- * Basic socket address type
186192
, SockAddr(..)
193+
, sockAddrFamily
187194
, isSupportedSockAddr
188195
, getPeerName
189196
, getSocketName

Network/Socket/Info.hsc

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77

88
module Network.Socket.Info where
99

10+
import Control.Exception (try, IOException)
1011
import Foreign.Marshal.Alloc (alloca, allocaBytes)
1112
import Foreign.Marshal.Utils (maybeWith, with)
1213
import GHC.IO.Exception (IOErrorType(NoSuchThing))
1314
import System.IO.Error (ioeSetErrorString, mkIOError)
15+
import System.IO.Unsafe (unsafePerformIO)
16+
import Text.Read (readEither)
1417

1518
import Network.Socket.Imports
1619
import Network.Socket.Internal
17-
import Network.Socket.Syscall
20+
import Network.Socket.Syscall (socket)
1821
import Network.Socket.Types
1922

2023
-----------------------------------------------------------------------------
@@ -467,10 +470,74 @@ showHostAddress6 ha6@(a1, a2, a3, a4)
467470
scanl (\c i -> if i == 0 then c - 1 else 0) 0 fields `zip` [0..]
468471

469472
-----------------------------------------------------------------------------
470-
471473
-- | A utility function to open a socket with `AddrInfo`.
472474
-- This is a just wrapper for the following code:
473475
--
474476
-- > \addr -> socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
475477
openSocket :: AddrInfo -> IO Socket
476478
openSocket addr = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
479+
480+
-----------------------------------------------------------------------------
481+
-- SockEndpoint
482+
483+
-- | Read a string representing a socket endpoint.
484+
readSockEndpoint :: PortNumber -> String -> Either String SockEndpoint
485+
readSockEndpoint defPort hostport = case hostport of
486+
'/':_ -> Right $ EndpointByAddr $ SockAddrUnix hostport
487+
'[':tl -> case span ((/=) ']') tl of
488+
(_, []) -> Left $ "unterminated IPv6 address: " <> hostport
489+
(ipv6, _:port) -> case readAddr ipv6 of
490+
Nothing -> Left $ "invalid IPv6 address: " <> ipv6
491+
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
492+
_ -> case span ((/=) ':') hostport of
493+
(host, port) -> case readAddr host of
494+
Nothing -> EndpointByName host <$> readPort port
495+
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
496+
where
497+
readPort "" = Right defPort
498+
readPort ":" = Right defPort
499+
readPort (':':port) = case readEither port of
500+
Right p -> Right p
501+
Left _ -> Left $ "bad port: " <> port
502+
readPort x = Left $ "bad port: " <> x
503+
hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] }
504+
readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of
505+
Left e -> Nothing where _ = e :: IOException
506+
Right r -> Just (addrAddress (head r))
507+
sockAddrPort h p = case h of
508+
SockAddrInet _ a -> SockAddrInet p a
509+
SockAddrInet6 _ f a s -> SockAddrInet6 p f a s
510+
x -> x
511+
512+
showSockEndpoint :: SockEndpoint -> String
513+
showSockEndpoint n = case n of
514+
EndpointByName h p -> h <> ":" <> show p
515+
EndpointByAddr a -> show a
516+
517+
-- | Resolve a socket endpoint into a list of socket addresses.
518+
-- The result is always non-empty; Haskell throws an exception if name
519+
-- resolution fails.
520+
resolveEndpoint :: SockEndpoint -> IO [SockAddr]
521+
resolveEndpoint name = case name of
522+
EndpointByAddr a -> pure [a]
523+
EndpointByName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port))
524+
where
525+
hints = Just $ defaultHints { addrSocketType = Stream }
526+
-- prevents duplicates, otherwise getAddrInfo returns all socket types
527+
528+
-- | Shortcut for creating a socket from a socket endpoint.
529+
--
530+
-- >>> import Network.Socket
531+
-- >>> let Right sn = readSockEndpoint 0 "0.0.0.0:0"
532+
-- >>> (s, a) <- socketFromEndpoint sn head Stream defaultProtocol
533+
-- >>> bind s a
534+
socketFromEndpoint
535+
:: SockEndpoint
536+
-> ([SockAddr] -> SockAddr)
537+
-> SocketType
538+
-> ProtocolNumber
539+
-> IO (Socket, SockAddr)
540+
socketFromEndpoint end select stype protocol = do
541+
a <- select <$> resolveEndpoint end
542+
s <- socket (sockAddrFamily a) stype protocol
543+
pure (s, a)

Network/Socket/Types.hsc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ module Network.Socket.Types (
5353
, withNewSocketAddress
5454

5555
-- * Socket address type
56+
, SockEndpoint(..)
5657
, SockAddr(..)
58+
, sockAddrFamily
5759
, isSupportedSockAddr
5860
, HostAddress
5961
, hostAddressToTuple
@@ -1041,6 +1043,23 @@ type FlowInfo = Word32
10411043
-- | Scope identifier.
10421044
type ScopeID = Word32
10431045

1046+
-- | Socket endpoints.
1047+
--
1048+
-- A wrapper around socket addresses that also accommodates the
1049+
-- popular usage of specifying them by name, e.g. "example.com:80".
1050+
-- We don't support service names here (string aliases for port
1051+
-- numbers) because they also imply a particular socket type, which
1052+
-- is outside of the scope of this data type.
1053+
--
1054+
-- This roughly corresponds to the "authority" part of a URI, as
1055+
-- defined here: https://tools.ietf.org/html/rfc3986#section-3.2
1056+
--
1057+
-- See also 'Network.Socket.socketFromEndpoint'.
1058+
data SockEndpoint
1059+
= EndpointByName !String !PortNumber
1060+
| EndpointByAddr !SockAddr
1061+
deriving (Eq, Ord)
1062+
10441063
-- | Socket addresses.
10451064
-- The existence of a constructor does not necessarily imply that
10461065
-- that socket address type is supported on your system: see
@@ -1064,6 +1083,12 @@ instance NFData SockAddr where
10641083
rnf (SockAddrInet6 _ _ _ _) = ()
10651084
rnf (SockAddrUnix str) = rnf str
10661085

1086+
sockAddrFamily :: SockAddr -> Family
1087+
sockAddrFamily addr = case addr of
1088+
SockAddrInet _ _ -> AF_INET
1089+
SockAddrInet6 _ _ _ _ -> AF_INET6
1090+
SockAddrUnix _ -> AF_UNIX
1091+
10671092
-- | Is the socket address type supported on this system?
10681093
isSupportedSockAddr :: SockAddr -> Bool
10691094
isSupportedSockAddr addr = case addr of

tests/Network/SocketSpec.hs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,27 @@ spec = do
158158
-- check if an exception is not thrown.
159159
isSupportedSockAddr addr `shouldBe` True
160160

161+
it "endpoints API, IPv4" $ do
162+
let Right end = readSockEndpoint 0 "127.0.0.1:6001"
163+
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
164+
bind sock addr
165+
listen sock 1
166+
close sock
167+
168+
it "endpoints API, IPv6" $ do
169+
let Right end = readSockEndpoint 0 "[::1]:6001"
170+
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
171+
bind sock addr
172+
listen sock 1
173+
close sock
174+
175+
it "endpoints API, DNS" $ do
176+
let Right end = readSockEndpoint 0 "localhost:6001"
177+
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
178+
bind sock addr
179+
listen sock 1
180+
close sock
181+
161182
#if !defined(mingw32_HOST_OS)
162183
when isUnixDomainSocketAvailable $ do
163184
context "unix sockets" $ do

0 commit comments

Comments
 (0)