@@ -53,6 +53,9 @@ pub struct MinimumValue {
5353 watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
5454 updated_at : GracePeriod ,
5555 grace_period : Duration ,
56+
57+ #[ cfg( test) ]
58+ notify : std:: sync:: Arc < tokio:: sync:: Notify > ,
5659}
5760
5861struct CostModelWatcher {
@@ -61,6 +64,9 @@ struct CostModelWatcher {
6164 cost_models : CostModelMap ,
6265 global_model : GlobalModel ,
6366 updated_at : GracePeriod ,
67+
68+ #[ cfg( test) ]
69+ notify : std:: sync:: Arc < tokio:: sync:: Notify > ,
6470}
6571
6672impl CostModelWatcher {
@@ -71,12 +77,15 @@ impl CostModelWatcher {
7177 global_model : GlobalModel ,
7278 cancel_token : tokio_util:: sync:: CancellationToken ,
7379 grace_period : GracePeriod ,
80+ #[ cfg( test) ] notify : std:: sync:: Arc < tokio:: sync:: Notify > ,
7481 ) {
7582 let cost_model_watcher = CostModelWatcher {
7683 pgpool,
7784 global_model,
7885 cost_models,
7986 updated_at : grace_period,
87+ #[ cfg( test) ]
88+ notify,
8089 } ;
8190
8291 loop {
@@ -109,6 +118,8 @@ impl CostModelWatcher {
109118 // model cache.
110119 Err ( _) => self . handle_unexpected_notification ( payload) . await ,
111120 }
121+ #[ cfg( test) ]
122+ self . notify . notify_one ( ) ;
112123 }
113124
114125 fn handle_insert ( & self , deployment : String , model : String , variables : String ) {
@@ -200,6 +211,9 @@ impl MinimumValue {
200211 'cost_models_update_notification'",
201212 ) ;
202213
214+ #[ cfg( test) ]
215+ let notify = std:: sync:: Arc :: new ( tokio:: sync:: Notify :: new ( ) ) ;
216+
203217 let watcher_cancel_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
204218 tokio:: spawn ( CostModelWatcher :: cost_models_watcher (
205219 pgpool. clone ( ) ,
@@ -208,13 +222,17 @@ impl MinimumValue {
208222 global_model. clone ( ) ,
209223 watcher_cancel_token. clone ( ) ,
210224 updated_at. clone ( ) ,
225+ #[ cfg( test) ]
226+ notify. clone ( ) ,
211227 ) ) ;
212228 Self {
213229 global_model,
214230 cost_model_map,
215231 watcher_cancel_token,
216232 updated_at,
217233 grace_period,
234+ #[ cfg( test) ]
235+ notify,
218236 }
219237 }
220238
@@ -347,7 +365,7 @@ enum CostModelNotification {
347365#[ cfg( test) ]
348366mod tests {
349367 use std:: time:: Duration ;
350- use test_assets:: { create_signed_receipt, SignedReceiptRequest } ;
368+ use test_assets:: { create_signed_receipt, flush_messages , SignedReceiptRequest } ;
351369
352370 use sqlx:: PgPool ;
353371 use tap_core:: receipt:: { checks:: Check , Context , ReceiptWithState } ;
@@ -388,7 +406,8 @@ mod tests {
388406 // insert 2 cost models for different deployment_id
389407 let test_models = test:: test_data ( ) ;
390408 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
391- sleep ( Duration :: from_millis ( 200 ) ) . await ;
409+
410+ flush_messages ( & check. notify ) . await ;
392411
393412 assert_eq ! (
394413 check. cost_model_map. read( ) . unwrap( ) . len( ) ,
@@ -411,7 +430,7 @@ mod tests {
411430 . await
412431 . unwrap ( ) ;
413432
414- sleep ( Duration :: from_millis ( 200 ) ) . await ;
433+ check . notify . notified ( ) . await ;
415434
416435 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
417436 }
@@ -431,7 +450,8 @@ mod tests {
431450
432451 let global_model = global_cost_model ( ) ;
433452 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
434- sleep ( Duration :: from_millis ( 10 ) ) . await ;
453+
454+ check. notify . notified ( ) . await ;
435455
436456 assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
437457 }
@@ -449,7 +469,7 @@ mod tests {
449469 . await
450470 . unwrap ( ) ;
451471
452- sleep ( Duration :: from_millis ( 10 ) ) . await ;
472+ check . notify . notified ( ) . await ;
453473
454474 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
455475 }
@@ -461,7 +481,9 @@ mod tests {
461481
462482 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
463483
464- let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 1 ) ) . await ;
484+ let grace_period = Duration :: from_secs ( 1 ) ;
485+
486+ let check = MinimumValue :: new ( pgpool, grace_period) . await ;
465487
466488 let deployment_id = test_models[ 0 ] . deployment ;
467489 let mut ctx = Context :: new ( ) ;
@@ -511,7 +533,7 @@ mod tests {
511533 "Should accept since its inside grace period "
512534 ) ;
513535
514- sleep ( Duration :: from_millis ( 1010 ) ) . await ;
536+ sleep ( grace_period + Duration :: from_millis ( 10 ) ) . await ;
515537
516538 assert ! (
517539 check. check( & ctx, & receipt) . await . is_err( ) ,
0 commit comments