Skip to content

Commit b24bedc

Browse files
refactor: use channels on TestableActor to be able to verify sent messages (#647)
* refactor: use channels instead of Notify in TestableActor * test: use channel response to test received message
1 parent 8bf83f5 commit b24bedc

File tree

10 files changed

+161
-119
lines changed

10 files changed

+161
-119
lines changed

crates/service/src/tap/checks/value_check.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ use tap_core::receipt::{
2020
Context, WithValueAndTimestamp,
2121
};
2222
use thegraph_core::DeploymentId;
23+
#[cfg(test)]
24+
use tokio::sync::mpsc;
2325

2426
use crate::{
2527
database::cost_model,
@@ -55,7 +57,7 @@ pub struct MinimumValue {
5557
grace_period: Duration,
5658

5759
#[cfg(test)]
58-
notify: std::sync::Arc<tokio::sync::Notify>,
60+
msg_receiver: mpsc::Receiver<()>,
5961
}
6062

6163
struct CostModelWatcher {
@@ -66,7 +68,7 @@ struct CostModelWatcher {
6668
updated_at: GracePeriod,
6769

6870
#[cfg(test)]
69-
notify: std::sync::Arc<tokio::sync::Notify>,
71+
sender: mpsc::Sender<()>,
7072
}
7173

7274
impl CostModelWatcher {
@@ -77,15 +79,15 @@ impl CostModelWatcher {
7779
global_model: GlobalModel,
7880
cancel_token: tokio_util::sync::CancellationToken,
7981
grace_period: GracePeriod,
80-
#[cfg(test)] notify: std::sync::Arc<tokio::sync::Notify>,
82+
#[cfg(test)] sender: mpsc::Sender<()>,
8183
) {
8284
let cost_model_watcher = CostModelWatcher {
8385
pgpool,
8486
global_model,
8587
cost_models,
8688
updated_at: grace_period,
8789
#[cfg(test)]
88-
notify,
90+
sender,
8991
};
9092

9193
loop {
@@ -119,7 +121,7 @@ impl CostModelWatcher {
119121
Err(_) => self.handle_unexpected_notification(payload).await,
120122
}
121123
#[cfg(test)]
122-
self.notify.notify_one();
124+
self.sender.send(()).await.expect("Channel failed");
123125
}
124126

125127
fn handle_insert(&self, deployment: String, model: String, variables: String) {
@@ -212,7 +214,7 @@ impl MinimumValue {
212214
);
213215

214216
#[cfg(test)]
215-
let notify = std::sync::Arc::new(tokio::sync::Notify::new());
217+
let (sender, receiver) = mpsc::channel(10);
216218

217219
let watcher_cancel_token = tokio_util::sync::CancellationToken::new();
218220
tokio::spawn(CostModelWatcher::cost_models_watcher(
@@ -223,7 +225,7 @@ impl MinimumValue {
223225
watcher_cancel_token.clone(),
224226
updated_at.clone(),
225227
#[cfg(test)]
226-
notify.clone(),
228+
sender,
227229
));
228230
Self {
229231
global_model,
@@ -232,7 +234,7 @@ impl MinimumValue {
232234
updated_at,
233235
grace_period,
234236
#[cfg(test)]
235-
notify,
237+
msg_receiver: receiver,
236238
}
237239
}
238240

@@ -399,14 +401,14 @@ mod tests {
399401

400402
#[sqlx::test(migrations = "../../migrations")]
401403
async fn should_watch_model_insert(pgpool: PgPool) {
402-
let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
404+
let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
403405
assert_eq!(check.cost_model_map.read().unwrap().len(), 0);
404406

405407
// insert 2 cost models for different deployment_id
406408
let test_models = test::test_data();
407409
add_cost_models(&pgpool, to_db_models(test_models.clone())).await;
408410

409-
flush_messages(&check.notify).await;
411+
flush_messages(&mut check.msg_receiver).await;
410412

411413
assert_eq!(
412414
check.cost_model_map.read().unwrap().len(),
@@ -420,7 +422,7 @@ mod tests {
420422
let test_models = test::test_data();
421423
add_cost_models(&pgpool, to_db_models(test_models.clone())).await;
422424

423-
let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
425+
let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
424426
assert_eq!(check.cost_model_map.read().unwrap().len(), 2);
425427

426428
// remove
@@ -429,7 +431,7 @@ mod tests {
429431
.await
430432
.unwrap();
431433

432-
check.notify.notified().await;
434+
check.msg_receiver.recv().await.expect("Channel failed");
433435

434436
assert_eq!(check.cost_model_map.read().unwrap().len(), 0);
435437
}
@@ -445,12 +447,12 @@ mod tests {
445447

446448
#[sqlx::test(migrations = "../../migrations")]
447449
async fn should_watch_global_model(pgpool: PgPool) {
448-
let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
450+
let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
449451

450452
let global_model = global_cost_model();
451453
add_cost_models(&pgpool, vec![global_model.clone()]).await;
452454

453-
check.notify.notified().await;
455+
check.msg_receiver.recv().await.expect("Channel failed");
454456

455457
assert!(check.global_model.read().unwrap().is_some());
456458
}
@@ -460,15 +462,15 @@ mod tests {
460462
let global_model = global_cost_model();
461463
add_cost_models(&pgpool, vec![global_model.clone()]).await;
462464

463-
let check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
465+
let mut check = MinimumValue::new(pgpool.clone(), Duration::from_secs(0)).await;
464466
assert!(check.global_model.read().unwrap().is_some());
465467

466468
sqlx::query!(r#"DELETE FROM "CostModels""#)
467469
.execute(&pgpool)
468470
.await
469471
.unwrap();
470472

471-
check.notify.notified().await;
473+
check.msg_receiver.recv().await.expect("Channel failed");
472474

473475
assert_eq!(check.cost_model_map.read().unwrap().len(), 0);
474476
}

0 commit comments

Comments
 (0)