1
1
use std:: sync:: atomic:: { AtomicBool , Ordering } ;
2
2
use std:: sync:: Arc ;
3
3
use tokenizers:: Tokenizer ;
4
+ use tokio:: time:: Instant ;
4
5
use text_generation_client:: { Batch , NextTokenChooserParameters , Request , ShardedClient } ;
5
6
6
7
const TEST_INPUT : & str = "liveness" ;
@@ -25,9 +26,15 @@ impl Health {
25
26
}
26
27
27
28
pub ( crate ) async fn check ( & mut self ) -> bool {
28
- if self . generation_health . load ( Ordering :: SeqCst ) {
29
+ let generation_healthy = self . generation_health . load ( Ordering :: SeqCst ) ;
30
+
31
+ let mut guard = Guard { prefill : !generation_healthy, start_time : Some ( Instant :: now ( ) ) } ;
32
+
33
+ let ok = if generation_healthy {
29
34
// Generation is healthy, we only check that the shards are answering gRPC calls
30
- self . client . health ( ) . await . is_ok ( )
35
+ self . client . health ( ) . await
36
+ . map_err ( |err| tracing:: error!( "Basic shard healthcheck error: {err}" ) )
37
+ . is_ok ( )
31
38
} else {
32
39
// Generation is unhealthy, or no generation request has been sent yet
33
40
@@ -51,13 +58,32 @@ impl Health {
51
58
requests : vec ! [ liveness_request] ,
52
59
total_tokens : 1 ,
53
60
} ;
54
- // Skips the queue
61
+ // Skips the queue, but will still be serialized behind in-flight prefill/next_token requests
55
62
let value = self . client . prefill ( batch, vec ! [ ] ) . await
56
- . map_err ( |err| tracing:: error!( "Healthcheck error: {err}" ) )
63
+ . map_err ( |err| tracing:: error!( "Prefill healthcheck error: {err}" ) )
57
64
. is_ok ( ) ;
58
65
// Update generation health
59
66
self . generation_health . store ( value, Ordering :: SeqCst ) ;
60
67
value
68
+ } ;
69
+ guard. start_time = None ;
70
+ ok
71
+ }
72
+ }
73
+
74
/// Tracks one in-flight health check so that an early drop can be reported.
///
/// `check` builds a `Guard` before awaiting the shard and clears `start_time`
/// once the awaited call has returned. If the guard is dropped while
/// `start_time` is still set, its `Drop` impl logs a cancellation warning.
struct Guard {
    /// `true` when the check takes the prefill path (generation was not
    /// healthy), `false` for the basic shard health call.
    prefill: bool,
    /// When the check started; reset to `None` once the check has completed.
    start_time: Option<Instant>,
}
78
+
79
+ impl Drop for Guard {
80
+ fn drop ( & mut self ) {
81
+ if let Some ( start_time) = self . start_time {
82
+ tracing:: warn!(
83
+ "Healthcheck request cancelled during {} check after {}ms" ,
84
+ if self . prefill { "prefill" } else { "basic shard" } ,
85
+ start_time. elapsed( ) . as_millis( ) ,
86
+ )
61
87
}
62
88
}
63
89
}
0 commit comments