@@ -20,6 +20,8 @@ use tap_core::receipt::{
2020 Context , WithValueAndTimestamp ,
2121} ;
2222use thegraph_core:: DeploymentId ;
23+ #[ cfg( test) ]
24+ use tokio:: sync:: mpsc;
2325
2426use crate :: {
2527 database:: cost_model,
@@ -55,7 +57,7 @@ pub struct MinimumValue {
5557 grace_period : Duration ,
5658
5759 #[ cfg( test) ]
58- notify : std :: sync :: Arc < tokio :: sync :: Notify > ,
60+ msg_receiver : mpsc :: Receiver < ( ) > ,
5961}
6062
6163struct CostModelWatcher {
@@ -66,7 +68,7 @@ struct CostModelWatcher {
6668 updated_at : GracePeriod ,
6769
6870 #[ cfg( test) ]
69- notify : std :: sync :: Arc < tokio :: sync :: Notify > ,
71+ sender : mpsc :: Sender < ( ) > ,
7072}
7173
7274impl CostModelWatcher {
@@ -77,15 +79,15 @@ impl CostModelWatcher {
7779 global_model : GlobalModel ,
7880 cancel_token : tokio_util:: sync:: CancellationToken ,
7981 grace_period : GracePeriod ,
80- #[ cfg( test) ] notify : std :: sync :: Arc < tokio :: sync :: Notify > ,
82+ #[ cfg( test) ] sender : mpsc :: Sender < ( ) > ,
8183 ) {
8284 let cost_model_watcher = CostModelWatcher {
8385 pgpool,
8486 global_model,
8587 cost_models,
8688 updated_at : grace_period,
8789 #[ cfg( test) ]
88- notify ,
90+ sender ,
8991 } ;
9092
9193 loop {
@@ -119,7 +121,7 @@ impl CostModelWatcher {
119121 Err ( _) => self . handle_unexpected_notification ( payload) . await ,
120122 }
121123 #[ cfg( test) ]
122- self . notify . notify_one ( ) ;
124+ self . sender . send ( ( ) ) . await . expect ( "Channel failed" ) ;
123125 }
124126
125127 fn handle_insert ( & self , deployment : String , model : String , variables : String ) {
@@ -212,7 +214,7 @@ impl MinimumValue {
212214 ) ;
213215
214216 #[ cfg( test) ]
215- let notify = std :: sync :: Arc :: new ( tokio :: sync :: Notify :: new ( ) ) ;
217+ let ( sender , receiver ) = mpsc :: channel ( 10 ) ;
216218
217219 let watcher_cancel_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
218220 tokio:: spawn ( CostModelWatcher :: cost_models_watcher (
@@ -223,7 +225,7 @@ impl MinimumValue {
223225 watcher_cancel_token. clone ( ) ,
224226 updated_at. clone ( ) ,
225227 #[ cfg( test) ]
226- notify . clone ( ) ,
228+ sender ,
227229 ) ) ;
228230 Self {
229231 global_model,
@@ -232,7 +234,7 @@ impl MinimumValue {
232234 updated_at,
233235 grace_period,
234236 #[ cfg( test) ]
235- notify ,
237+ msg_receiver : receiver ,
236238 }
237239 }
238240
@@ -399,14 +401,14 @@ mod tests {
399401
400402 #[ sqlx:: test( migrations = "../../migrations" ) ]
401403 async fn should_watch_model_insert ( pgpool : PgPool ) {
402- let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
404+ let mut check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
403405 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
404406
405407 // insert 2 cost models for different deployment_id
406408 let test_models = test:: test_data ( ) ;
407409 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
408410
409- flush_messages ( & check. notify ) . await ;
411+ flush_messages ( & mut check. msg_receiver ) . await ;
410412
411413 assert_eq ! (
412414 check. cost_model_map. read( ) . unwrap( ) . len( ) ,
@@ -420,7 +422,7 @@ mod tests {
420422 let test_models = test:: test_data ( ) ;
421423 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
422424
423- let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
425+ let mut check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
424426 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 2 ) ;
425427
426428 // remove
@@ -429,7 +431,7 @@ mod tests {
429431 . await
430432 . unwrap ( ) ;
431433
432- check. notify . notified ( ) . await ;
434+ check. msg_receiver . recv ( ) . await . expect ( "Channel failed" ) ;
433435
434436 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
435437 }
@@ -445,12 +447,12 @@ mod tests {
445447
446448 #[ sqlx:: test( migrations = "../../migrations" ) ]
447449 async fn should_watch_global_model ( pgpool : PgPool ) {
448- let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
450+ let mut check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
449451
450452 let global_model = global_cost_model ( ) ;
451453 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
452454
453- check. notify . notified ( ) . await ;
455+ check. msg_receiver . recv ( ) . await . expect ( "Channel failed" ) ;
454456
455457 assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
456458 }
@@ -460,15 +462,15 @@ mod tests {
460462 let global_model = global_cost_model ( ) ;
461463 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
462464
463- let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
465+ let mut check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
464466 assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
465467
466468 sqlx:: query!( r#"DELETE FROM "CostModels""# )
467469 . execute ( & pgpool)
468470 . await
469471 . unwrap ( ) ;
470472
471- check. notify . notified ( ) . await ;
473+ check. msg_receiver . recv ( ) . await . expect ( "Channel failed" ) ;
472474
473475 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
474476 }
0 commit comments