@@ -20,6 +20,7 @@ use std::{
20
20
time:: Duration ,
21
21
} ;
22
22
23
+ use event_listener:: { Event , EventListener } ;
23
24
use futures_util:: { stream:: SelectAll , Stream , StreamExt } ;
24
25
use http_body:: Body ;
25
26
use hyper:: { body:: HttpBody , server:: conn:: Connection , Request , Response } ;
@@ -29,6 +30,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
29
30
use tokio_rustls:: rustls:: ServerConfig ;
30
31
use tower_http:: add_extension:: AddExtension ;
31
32
use tower_service:: Service ;
33
+ use tracing:: Instrument ;
32
34
33
35
use crate :: {
34
36
maybe_tls:: { MaybeTlsAcceptor , MaybeTlsStream , TlsStreamInfo } ,
@@ -153,6 +155,17 @@ impl AcceptError {
153
155
/// Returns an error if the proxy protocol or TLS handshake failed.
154
156
/// Returns the connection, which should be used to spawn a task to serve the
155
157
/// connection.
158
+ #[ allow( clippy:: type_complexity) ]
159
+ #[ tracing:: instrument(
160
+ name = "accept" ,
161
+ skip_all,
162
+ fields(
163
+ network. protocol. name = "http" ,
164
+ network. peer. address,
165
+ network. peer. port,
166
+ ) ,
167
+ err,
168
+ ) ]
156
169
async fn accept < S , B > (
157
170
maybe_proxy_acceptor : & MaybeProxyAcceptor ,
158
171
maybe_tls_acceptor : & MaybeTlsAcceptor ,
@@ -171,6 +184,18 @@ where
171
184
B :: Data : Send + ' static ,
172
185
B :: Error : std:: error:: Error + Send + Sync + ' static ,
173
186
{
187
+ let span = tracing:: Span :: current ( ) ;
188
+
189
+ match peer_addr {
190
+ SocketAddr :: Net ( addr) => {
191
+ span. record ( "network.peer.address" , tracing:: field:: display ( addr. ip ( ) ) ) ;
192
+ span. record ( "network.peer.port" , addr. port ( ) ) ;
193
+ }
194
+ SocketAddr :: Unix ( ref addr) => {
195
+ span. record ( "network.peer.address" , tracing:: field:: debug ( addr) ) ;
196
+ }
197
+ }
198
+
174
199
// Wrap the connection acceptation logic in a timeout
175
200
tokio:: time:: timeout ( HANDSHAKE_TIMEOUT , async move {
176
201
let ( proxy, stream) = maybe_proxy_acceptor
@@ -209,6 +234,7 @@ where
209
234
210
235
Ok ( conn)
211
236
} )
237
+ . instrument ( span)
212
238
. await
213
239
. map_err ( AcceptError :: handshake_timeout) ?
214
240
}
@@ -220,19 +246,28 @@ pin_project! {
220
246
/// signal is received, the boolean is set to true. The connection will then check the
221
247
/// boolean before polling the underlying connection, and if it's true, it will start a
222
248
/// graceful shutdown.
249
+ ///
250
+ /// We also use an event listener to wake up the connection when the shutdown signal is
251
+ /// received, because the connection needs to be polled again to start the graceful shutdown.
223
252
struct AbortableConnection <C > {
224
253
#[ pin]
225
254
connection: C ,
255
+ #[ pin]
256
+ shutdown_listener: EventListener ,
257
+ shutdown_event: Arc <Event >,
226
258
shutdown_in_progress: Arc <AtomicBool >,
227
259
did_start_shutdown: bool ,
228
260
}
229
261
}
230
262
231
263
impl < C > AbortableConnection < C > {
232
- fn new ( connection : C , shutdown_in_progress : & Arc < AtomicBool > ) -> Self {
264
+ fn new ( connection : C , shutdown_in_progress : & Arc < AtomicBool > , event : & Arc < Event > ) -> Self {
265
+ let shutdown_listener = EventListener :: new ( ) ;
233
266
Self {
234
267
connection,
268
+ shutdown_listener,
235
269
shutdown_in_progress : Arc :: clone ( shutdown_in_progress) ,
270
+ shutdown_event : Arc :: clone ( event) ,
236
271
did_start_shutdown : false ,
237
272
}
238
273
}
@@ -254,10 +289,20 @@ where
254
289
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
255
290
let mut this = self . project ( ) ;
256
291
257
- // XXX: This assumes the task will be polled again after the shutdown signal is
258
- // received. I *think* that internally `graceful_shutdown` is only
259
- // setting a bunch of flags anyway, so I expect it should be fine to not have a
260
- // waker here.
292
+ // If we aren't listening for the shutdown signal, start listening
293
+ if !this. shutdown_listener . is_listening ( ) {
294
+ // XXX: it feels like we should setup the listener when we create it, but it
295
+ // needs a `Pin<&mut EventListener>` to do so, and I can't figure out
296
+ // how to get one outside of the `poll` method.
297
+ this. shutdown_listener . as_mut ( ) . listen ( this. shutdown_event ) ;
298
+ }
299
+
300
+ // Poll the shutdown signal, so that wakers get registered.
301
+ // XXX: I don't think we care about the result of this poll, since it's only
302
+ // really to register wakers. But I'm not sure if it's safe to
303
+ // ignore the result.
304
+ let _ = this. shutdown_listener . poll ( cx) ;
305
+
261
306
if !* this. did_start_shutdown
262
307
&& this
263
308
. shutdown_in_progress
@@ -312,6 +357,7 @@ where
312
357
313
358
// A shared atomic boolean to tell all connections to shutdown
314
359
let shutdown_in_progress = Arc :: new ( AtomicBool :: new ( false ) ) ;
360
+ let shutdown_event = Arc :: new ( Event :: new ( ) ) ;
315
361
316
362
loop {
317
363
tokio:: select! {
@@ -330,10 +376,10 @@ where
330
376
match res {
331
377
Some ( Ok ( Ok ( connection) ) ) => {
332
378
tracing:: trace!( "Accepted connection" ) ;
333
- let conn = AbortableConnection :: new( connection, & shutdown_in_progress) ;
379
+ let conn = AbortableConnection :: new( connection, & shutdown_in_progress, & shutdown_event ) ;
334
380
connection_tasks. spawn( conn) ;
335
381
} ,
336
- Some ( Ok ( Err ( e ) ) ) => tracing :: error! ( " Connection did not finish handshake: {e}" ) ,
382
+ Some ( Ok ( Err ( _e ) ) ) => { /* Connection did not finish handshake, error should be logged in `accept` */ } ,
337
383
Some ( Err ( e) ) => tracing:: error!( "Join error: {e}" ) ,
338
384
None => tracing:: error!( "Join set was polled even though it was empty" ) ,
339
385
}
@@ -368,6 +414,7 @@ where
368
414
369
415
// Tell the active connections to shutdown
370
416
shutdown_in_progress. store ( true , std:: sync:: atomic:: Ordering :: Relaxed ) ;
417
+ shutdown_event. notify ( usize:: MAX ) ;
371
418
372
419
// Wait for connections to cleanup
373
420
if !accept_tasks. is_empty ( ) || !connection_tasks. is_empty ( ) {
@@ -386,10 +433,10 @@ where
386
433
match res {
387
434
Some ( Ok ( Ok ( connection) ) ) => {
388
435
tracing:: trace!( "Accepted connection" ) ;
389
- let conn = AbortableConnection :: new( connection, & shutdown_in_progress) ;
436
+ let conn = AbortableConnection :: new( connection, & shutdown_in_progress, & shutdown_event ) ;
390
437
connection_tasks. spawn( conn) ;
391
438
}
392
- Some ( Ok ( Err ( e ) ) ) => tracing :: error! ( " Connection did not finish handshake: {e}" ) ,
439
+ Some ( Ok ( Err ( _e ) ) ) => { /* Connection did not finish handshake, error should be logged in `accept` */ } ,
393
440
Some ( Err ( e) ) => tracing:: error!( "Join error: {e}" ) ,
394
441
None => tracing:: error!( "Join set was polled even though it was empty" ) ,
395
442
}
0 commit comments