1+ use crate :: config:: ServerConfig ;
12use core:: fmt;
3+ use socket2:: { Domain , Type } ;
24use std:: fmt:: Formatter ;
35use std:: io;
46use std:: io:: { Error , ErrorKind } ;
57use std:: net:: { IpAddr , SocketAddr } ;
68use std:: str:: FromStr ;
7- use socket2:: { Domain , Type } ;
89use tokio:: net:: { TcpListener , TcpStream } ;
9- use crate :: config:: ServerConfig ;
1010
1111pub struct TcpSocket {
12- tcp_listener : TcpListener ,
13- addr : TcpAddr
12+ tcp_listeners : Vec < TcpListener > ,
13+ addrs : Vec < TcpAddr > ,
14+ from_sys : bool ,
1415}
1516
1617impl TcpSocket {
17- pub fn bind ( config : & ServerConfig ) -> io:: Result < TcpSocket > {
18- let tcp_addr = TcpAddr :: new ( config) ?;
19- let socket = socket2:: Socket :: new ( tcp_addr. domain , Type :: STREAM , None ) ?;
18+ pub fn make_listener ( config : & ServerConfig ) -> io:: Result < TcpSocket > {
19+ let fd_listeners = Self :: find_fd_listeners ( ) ?;
20+ let ( tcp_listeners, tcp_addr, from_sys) = if !fd_listeners. 0 . is_empty ( ) {
21+ ( fd_listeners. 0 , fd_listeners. 1 , true )
22+ } else {
23+ let tcp_addr = TcpAddr :: from_config ( config) ?;
24+ ( vec ! [ Self :: bind( & tcp_addr) ?] , vec ! [ tcp_addr] , false )
25+ } ;
26+ Ok ( TcpSocket {
27+ tcp_listeners,
28+ addrs : tcp_addr,
29+ from_sys,
30+ } )
31+ }
32+
33+ fn find_fd_listeners ( ) -> io:: Result < ( Vec < TcpListener > , Vec < TcpAddr > ) > {
34+ let mut listen_fd = listenfd:: ListenFd :: from_env ( ) ;
35+ let mut fd_listeners = Vec :: new ( ) ;
36+ let mut addrs = Vec :: new ( ) ;
37+ if listen_fd. len ( ) > 0 {
38+ for index in 0 ..listen_fd. len ( ) {
39+ if let Ok ( Some ( listener) ) = listen_fd. take_tcp_listener ( index) {
40+ listener. set_nonblocking ( true ) ?; // PLD Point
41+ let tcp_listener = TcpListener :: from_std ( listener) ?;
42+ addrs. push ( TcpAddr :: from_socket ( & tcp_listener) ?) ;
43+ fd_listeners. push ( tcp_listener) ;
44+ }
45+ }
46+ return Ok ( ( fd_listeners, addrs) ) ;
47+ }
48+ Ok ( ( Vec :: new ( ) , Vec :: new ( ) ) )
49+ }
50+
51+ fn bind ( tcp_addr : & TcpAddr ) -> io:: Result < TcpListener > {
52+ let socket = socket2:: Socket :: new ( tcp_addr. domain , Type :: STREAM , None ) ?;
2053 if !tcp_addr. is_only_v6 {
2154 socket. set_only_v6 ( false ) ?;
2255 }
2356 socket. set_reuse_address ( true ) ?;
2457 socket. bind ( & tcp_addr. sock_addr . into ( ) ) ?;
2558 socket. listen ( 128 ) ?;
2659 let tcp_listener = TcpListener :: from_std ( socket. into ( ) ) ?;
27- Ok ( TcpSocket {
28- tcp_listener,
29- addr : tcp_addr
30- } )
60+ Ok ( tcp_listener)
3161 }
3262
33- pub async fn accept ( & self ) -> io:: Result < ( TcpStream , SocketAddr ) > {
34- self . tcp_listener . accept ( ) . await
63+ pub async fn accept ( & self ) -> io:: Result < ( TcpStream , SocketAddr ) > {
64+ for listener in & self . tcp_listeners {
65+ tokio:: select! {
66+ result = listener. accept( ) => {
67+ let ( stream, addr) = result?;
68+ return Ok ( ( stream, addr) ) ;
69+ }
70+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_millis( 1 ) ) => { }
71+ }
72+ }
73+ Err ( Error :: new ( ErrorKind :: Other , "No listeners found." ) )
3574 }
36-
3775}
3876
3977impl fmt:: Display for TcpSocket {
4078 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> fmt:: Result {
41- write ! ( f, "{}" , self . addr. sock_addr)
79+ match self . from_sys {
80+ true => {
81+ for ( i, addr) in self . addrs . iter ( ) . enumerate ( ) {
82+ if i > 0 {
83+ write ! ( f, ", " ) ?;
84+ }
85+ write ! ( f, "{}" , addr. sock_addr) ?;
86+ }
87+ Ok ( ( ) )
88+ }
89+ false => {
90+ write ! ( f, "{}" , self . addrs[ 0 ] . sock_addr)
91+ }
92+ }
4293 }
4394}
4495
@@ -50,24 +101,46 @@ pub struct TcpAddr {
50101}
51102
52103impl TcpAddr {
53- pub fn new ( config : & ServerConfig ) -> io:: Result < Self > {
104+ pub fn from_config ( config : & ServerConfig ) -> io:: Result < Self > {
54105 let bind_addr = config. bind_address . as_str ( ) ;
55- let parse_addr = Self :: parse_addr ( bind_addr ) ?;
56- let addr = SocketAddr :: new ( parse_addr . 0 , config. listen_port ) ;
106+ let parsed_addr = bind_addr . parse_addr ( ) ?;
107+ let addr = SocketAddr :: new ( parsed_addr . 0 , config. listen_port ) ;
57108 Ok ( TcpAddr {
58109 sock_addr : addr,
59- domain : parse_addr . 1 ,
110+ domain : parsed_addr . 1 ,
60111 is_only_v6 : bind_addr != "::" && bind_addr != "::0" ,
61112 } )
62113 }
63114
64- fn parse_addr ( ip_str : & str ) -> io:: Result < ( IpAddr , Domain ) > {
65- match IpAddr :: from_str ( ip_str) {
115+ pub fn from_socket ( listener : & TcpListener ) -> io:: Result < Self > {
116+ let socket_addr = listener. local_addr ( ) ?;
117+ let parsed_addr = socket_addr. parse_addr ( ) ?;
118+ Ok ( TcpAddr {
119+ sock_addr : socket_addr,
120+ domain : parsed_addr. 1 ,
121+ is_only_v6 : false ,
122+ } )
123+ }
124+ }
125+
126+ trait IpParser {
127+ fn parse_addr ( & self ) -> io:: Result < ( IpAddr , Domain ) > ;
128+ }
129+
130+ impl IpParser for & str {
131+ fn parse_addr ( & self ) -> io:: Result < ( IpAddr , Domain ) > {
132+ match IpAddr :: from_str ( self ) {
66133 Ok ( ip) => match ip {
67134 IpAddr :: V4 ( _) => Ok ( ( ip, Domain :: IPV4 ) ) ,
68135 IpAddr :: V6 ( _) => Ok ( ( ip, Domain :: IPV6 ) ) ,
69136 } ,
70137 Err ( e) => Err ( Error :: new ( ErrorKind :: Other , e) ) ,
71138 }
72139 }
140+ }
141+
142+ impl IpParser for SocketAddr {
143+ fn parse_addr ( & self ) -> io:: Result < ( IpAddr , Domain ) > {
144+ Ok ( ( self . ip ( ) , if self . is_ipv4 ( ) { Domain :: IPV4 } else { Domain :: IPV6 } ) )
145+ }
73146}
0 commit comments