3
3
// SPDX-License-Identifier: Apache-2.0
4
4
//
5
5
6
+ use nix:: unistd;
6
7
use protobuf:: { CodedInputStream , Message } ;
7
8
use std:: collections:: HashMap ;
8
9
use std:: os:: unix:: io:: RawFd ;
@@ -25,6 +26,7 @@ use tokio::{
25
26
prelude:: * ,
26
27
stream:: Stream ,
27
28
sync:: mpsc:: { channel, Receiver , Sender } ,
29
+ sync:: watch,
28
30
} ;
29
31
use tokio_vsock:: VsockListener ;
30
32
@@ -33,6 +35,9 @@ pub struct Server {
33
35
listeners : Vec < RawFd > ,
34
36
methods : Arc < HashMap < String , Box < dyn MethodHandler + Send + Sync > > > ,
35
37
domain : Option < Domain > ,
38
+ disconnect_tx : Option < watch:: Sender < i32 > > ,
39
+ all_conn_done_rx : Option < Receiver < i32 > > ,
40
+ stop_listen_tx : Option < Sender < Sender < RawFd > > > ,
36
41
}
37
42
38
43
impl Default for Server {
@@ -41,6 +46,9 @@ impl Default for Server {
41
46
listeners : Vec :: with_capacity ( 1 ) ,
42
47
methods : Arc :: new ( HashMap :: new ( ) ) ,
43
48
domain : None ,
49
+ disconnect_tx : None ,
50
+ all_conn_done_rx : None ,
51
+ stop_listen_tx : None ,
44
52
}
45
53
}
46
54
}
@@ -60,6 +68,7 @@ impl Server {
60
68
let ( fd, domain) = common:: do_bind ( host) ?;
61
69
self . domain = Some ( domain) ;
62
70
71
+ common:: do_listen ( fd) ?;
63
72
self . listeners . push ( fd) ;
64
73
Ok ( self )
65
74
}
@@ -79,30 +88,27 @@ impl Server {
79
88
self
80
89
}
81
90
82
- fn listen ( & self ) -> Result < RawFd > {
91
+ fn get_listenfd ( & self ) -> Result < RawFd > {
83
92
if self . listeners . is_empty ( ) {
84
93
return Err ( Error :: Others ( "ttrpc-rust not bind" . to_string ( ) ) ) ;
85
94
}
86
95
87
- let listenfd = self . listeners [ 0 ] ;
88
- common:: do_listen ( listenfd) ?;
89
-
96
+ let listenfd = self . listeners [ self . listeners . len ( ) - 1 ] ;
90
97
Ok ( listenfd)
91
98
}
92
99
93
- pub async fn start ( & self ) -> Result < ( ) > {
94
- let listenfd = self . listen ( ) ?;
100
+ pub async fn start ( & mut self ) -> Result < ( ) > {
101
+ let listenfd = self . get_listenfd ( ) ?;
95
102
96
103
match self . domain . as_ref ( ) . unwrap ( ) {
97
104
Domain :: Unix => {
98
105
let sys_unix_listener;
99
106
unsafe {
100
107
sys_unix_listener = SysUnixListener :: from_raw_fd ( listenfd) ;
101
108
}
102
- let mut unix_listener = UnixListener :: from_std ( sys_unix_listener) . unwrap ( ) ;
103
- let incoming = unix_listener. incoming ( ) ;
109
+ let unix_listener = UnixListener :: from_std ( sys_unix_listener) . unwrap ( ) ;
104
110
105
- self . do_start ( listenfd, incoming ) . await
111
+ self . do_start ( listenfd, unix_listener ) . await
106
112
}
107
113
Domain :: Vsock => {
108
114
let incoming;
@@ -115,52 +121,143 @@ impl Server {
115
121
}
116
122
}
117
123
118
- pub async fn do_start < I , S > ( & self , listenfd : RawFd , mut incoming : I ) -> Result < ( ) >
124
+ pub async fn do_start < I , S > ( & mut self , listenfd : RawFd , mut incoming : I ) -> Result < ( ) >
119
125
where
120
- I : Stream < Item = std:: io:: Result < S > > + Unpin ,
126
+ I : Stream < Item = std:: io:: Result < S > > + Unpin + Send + ' static + AsRawFd ,
121
127
S : AsyncRead + AsyncWrite + AsRawFd + Send + ' static ,
122
128
{
123
- while let Some ( result) = incoming. next ( ) . await {
124
- match result {
125
- Ok ( stream) => {
126
- common:: set_fd_close_exec ( stream. as_raw_fd ( ) ) ?;
127
- let methods = self . methods . clone ( ) ;
128
- tokio:: spawn ( async move {
129
- let ( mut reader, mut writer) = split ( stream) ;
130
- let ( tx, mut rx) : ( Sender < Vec < u8 > > , Receiver < Vec < u8 > > ) = channel ( 100 ) ;
131
-
132
- tokio:: spawn ( async move {
133
- while let Some ( buf) = rx. recv ( ) . await {
134
- if let Err ( e) = writer. write_all ( & buf) . await {
135
- error ! ( "write_message got error: {:?}" , e) ;
136
- }
137
- }
138
- } ) ;
129
+ let methods = self . methods . clone ( ) ;
130
+ let ( disconnect_tx, close_conn_rx) = watch:: channel ( 0 ) ;
131
+ self . disconnect_tx = Some ( disconnect_tx) ;
139
132
140
- loop {
141
- let tx = tx. clone ( ) ;
133
+ let ( conn_done_tx, all_conn_done_rx) = channel :: < i32 > ( 1 ) ;
134
+
135
+ self . all_conn_done_rx = Some ( all_conn_done_rx) ;
136
+ let ( stop_listen_tx, mut stop_listen_rx) = channel ( 1 ) ;
137
+ self . stop_listen_tx = Some ( stop_listen_tx) ;
138
+
139
+ tokio:: spawn ( async move {
140
+ loop {
141
+ tokio:: select! {
142
+ conn = incoming. next( ) => {
143
+ if let Some ( conn) = conn {
144
+ // Accept a new connection
142
145
let methods = methods. clone( ) ;
146
+ match conn {
147
+ Ok ( stream) => {
148
+ let fd = stream. as_raw_fd( ) ;
149
+ if let Err ( e) = common:: set_fd_close_exec( fd) {
150
+ error!( "{:?}" , e) ;
151
+ continue ;
152
+ }
153
+
154
+ let mut close_conn_rx = close_conn_rx. clone( ) ;
143
155
144
- match receive ( & mut reader) . await {
145
- Ok ( message) => {
156
+ let ( req_done_tx, mut all_req_done_rx) = channel:: <i32 >( 1 ) ;
157
+ let conn_done_tx2 = conn_done_tx. clone( ) ;
158
+
159
+ // The connection handler
146
160
tokio:: spawn( async move {
147
- handle_request ( tx, listenfd, methods, message) . await ;
161
+ let ( mut reader, mut writer) = split( stream) ;
162
+ let ( tx, mut rx) : ( Sender <Vec <u8 >>, Receiver <Vec <u8 >>) = channel( 100 ) ;
163
+
164
+ tokio:: spawn( async move {
165
+ while let Some ( buf) = rx. recv( ) . await {
166
+ if let Err ( e) = writer. write_all( & buf) . await {
167
+ error!( "write_message got error: {:?}" , e) ;
168
+ }
169
+ }
170
+ } ) ;
171
+
172
+ loop {
173
+ let tx = tx. clone( ) ;
174
+ let methods = methods. clone( ) ;
175
+ let req_done_tx2 = req_done_tx. clone( ) ;
176
+
177
+ tokio:: select! {
178
+ resp = receive( & mut reader) => {
179
+ match resp {
180
+ Ok ( message) => {
181
+ tokio:: spawn( async move {
182
+ handle_request( tx, listenfd, methods, message) . await ;
183
+ drop( req_done_tx2) ;
184
+ } ) ;
185
+ }
186
+ Err ( e) => {
187
+ trace!( "error {:?}" , e) ;
188
+ break ;
189
+ }
190
+ }
191
+ }
192
+ v = close_conn_rx. recv( ) => {
193
+ // 0 is the init value of this watch, not a valid signal
194
+ // is_none means the tx was dropped.
195
+ if v. is_none( ) || v. unwrap( ) != 0 {
196
+ info!( "Stop accepting new connections." ) ;
197
+ break ;
198
+ }
199
+ }
200
+ }
201
+ }
202
+
203
+ drop( req_done_tx) ;
204
+ all_req_done_rx. recv( ) . await ;
205
+ drop( conn_done_tx2) ;
148
206
} ) ;
149
207
}
150
208
Err ( e) => {
151
- trace ! ( "error {:?}" , e) ;
152
- break ;
209
+ error!( "{:?}" , e)
153
210
}
154
211
}
212
+
213
+ } else {
214
+ break ;
215
+ }
216
+ }
217
+ fd_tx = stop_listen_rx. recv( ) => {
218
+ if let Some ( mut fd_tx) = fd_tx {
219
+ let dup_fd = unistd:: dup( incoming. as_raw_fd( ) ) . unwrap( ) ;
220
+ common:: set_fd_close_exec( dup_fd) . unwrap( ) ;
221
+ drop( incoming) ;
222
+
223
+ fd_tx. send( dup_fd) . await . unwrap( ) ;
224
+ break ;
155
225
}
156
- } ) ;
226
+ }
157
227
}
158
- Err ( e) => error ! ( "{:?}" , e) ,
159
228
}
160
- }
229
+ drop ( conn_done_tx) ;
230
+ } ) ;
231
+ Ok ( ( ) )
232
+ }
233
+
234
+ pub async fn shutdown ( & mut self ) -> Result < ( ) > {
235
+ self . stop_listen ( ) . await ;
236
+ self . disconnect ( ) . await ;
161
237
162
238
Ok ( ( ) )
163
239
}
240
+
241
+ pub async fn disconnect ( & mut self ) {
242
+ if let Some ( tx) = self . disconnect_tx . take ( ) {
243
+ tx. broadcast ( 1 ) . ok ( ) ;
244
+ }
245
+
246
+ if let Some ( mut rx) = self . all_conn_done_rx . take ( ) {
247
+ rx. recv ( ) . await ;
248
+ }
249
+ }
250
+
251
+ pub async fn stop_listen ( & mut self ) {
252
+ if let Some ( mut tx) = self . stop_listen_tx . take ( ) {
253
+ let ( fd_tx, mut fd_rx) = channel ( 1 ) ;
254
+ tx. send ( fd_tx) . await . unwrap ( ) ;
255
+
256
+ let fd = fd_rx. recv ( ) . await . unwrap ( ) ;
257
+ self . listeners . clear ( ) ;
258
+ self . listeners . push ( fd) ;
259
+ }
260
+ }
164
261
}
165
262
166
263
async fn handle_request (
0 commit comments