|
7 | 7 |
|
8 | 8 | module Network.Socket.Info where
|
9 | 9 |
|
| 10 | +import Control.Exception (try, IOException) |
10 | 11 | import Foreign.Marshal.Alloc (alloca, allocaBytes)
|
11 | 12 | import Foreign.Marshal.Utils (maybeWith, with)
|
12 | 13 | import GHC.IO.Exception (IOErrorType(NoSuchThing))
|
13 | 14 | import System.IO.Error (ioeSetErrorString, mkIOError)
|
| 15 | +import System.IO.Unsafe (unsafePerformIO) |
| 16 | +import Text.Read (readEither) |
14 | 17 |
|
15 | 18 | import Network.Socket.Imports
|
16 | 19 | import Network.Socket.Internal
|
17 |
| -import Network.Socket.Syscall |
| 20 | +import Network.Socket.Syscall (socket) |
18 | 21 | import Network.Socket.Types
|
19 | 22 |
|
20 | 23 | -----------------------------------------------------------------------------
|
@@ -467,10 +470,74 @@ showHostAddress6 ha6@(a1, a2, a3, a4)
|
467 | 470 | scanl (\c i -> if i == 0 then c - 1 else 0) 0 fields `zip` [0..]
|
468 | 471 |
|
469 | 472 | -----------------------------------------------------------------------------
|
470 |
| - |
471 | 473 | -- | A utility function to open a socket with `AddrInfo`.
|
472 | 474 | -- This is a just wrapper for the following code:
|
473 | 475 | --
|
474 | 476 | -- > \addr -> socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
|
475 | 477 | openSocket :: AddrInfo -> IO Socket
|
476 | 478 | openSocket addr = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
|
| 479 | + |
| 480 | +----------------------------------------------------------------------------- |
| 481 | +-- SockEndpoint |
| 482 | + |
| 483 | +-- | Read a string representing a socket endpoint. |
| 484 | +readSockEndpoint :: PortNumber -> String -> Either String SockEndpoint |
| 485 | +readSockEndpoint defPort hostport = case hostport of |
| 486 | + '/':_ -> Right $ EndpointByAddr $ SockAddrUnix hostport |
| 487 | + '[':tl -> case span ((/=) ']') tl of |
| 488 | + (_, []) -> Left $ "unterminated IPv6 address: " <> hostport |
| 489 | + (ipv6, _:port) -> case readAddr ipv6 of |
| 490 | + Nothing -> Left $ "invalid IPv6 address: " <> ipv6 |
| 491 | + Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port |
| 492 | + _ -> case span ((/=) ':') hostport of |
| 493 | + (host, port) -> case readAddr host of |
| 494 | + Nothing -> EndpointByName host <$> readPort port |
| 495 | + Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port |
| 496 | + where |
| 497 | + readPort "" = Right defPort |
| 498 | + readPort ":" = Right defPort |
| 499 | + readPort (':':port) = case readEither port of |
| 500 | + Right p -> Right p |
| 501 | + Left _ -> Left $ "bad port: " <> port |
| 502 | + readPort x = Left $ "bad port: " <> x |
| 503 | + hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] } |
| 504 | + readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of |
| 505 | + Left e -> Nothing where _ = e :: IOException |
| 506 | + Right r -> Just (addrAddress (head r)) |
| 507 | + sockAddrPort h p = case h of |
| 508 | + SockAddrInet _ a -> SockAddrInet p a |
| 509 | + SockAddrInet6 _ f a s -> SockAddrInet6 p f a s |
| 510 | + x -> x |
| 511 | + |
| 512 | +showSockEndpoint :: SockEndpoint -> String |
| 513 | +showSockEndpoint n = case n of |
| 514 | + EndpointByName h p -> h <> ":" <> show p |
| 515 | + EndpointByAddr a -> show a |
| 516 | + |
| 517 | +-- | Resolve a socket endpoint into a list of socket addresses. |
| 518 | +-- The result is always non-empty; Haskell throws an exception if name |
| 519 | +-- resolution fails. |
| 520 | +resolveEndpoint :: SockEndpoint -> IO [SockAddr] |
| 521 | +resolveEndpoint name = case name of |
| 522 | + EndpointByAddr a -> pure [a] |
| 523 | + EndpointByName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port)) |
| 524 | + where |
| 525 | + hints = Just $ defaultHints { addrSocketType = Stream } |
| 526 | + -- prevents duplicates, otherwise getAddrInfo returns all socket types |
| 527 | + |
| 528 | +-- | Shortcut for creating a socket from a socket endpoint. |
| 529 | +-- |
| 530 | +-- >>> import Network.Socket |
| 531 | +-- >>> let Right sn = readSockEndpoint 0 "0.0.0.0:0" |
| 532 | +-- >>> (s, a) <- socketFromEndpoint sn head Stream defaultProtocol |
| 533 | +-- >>> bind s a |
| 534 | +socketFromEndpoint |
| 535 | + :: SockEndpoint |
| 536 | + -> ([SockAddr] -> SockAddr) |
| 537 | + -> SocketType |
| 538 | + -> ProtocolNumber |
| 539 | + -> IO (Socket, SockAddr) |
| 540 | +socketFromEndpoint end select stype protocol = do |
| 541 | + a <- select <$> resolveEndpoint end |
| 542 | + s <- socket (sockAddrFamily a) stype protocol |
| 543 | + pure (s, a) |
0 commit comments