Skip to content

Commit 502e8e9

Browse files
committed
refactor: add value reload
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent 7818033 commit 502e8e9

File tree

1 file changed

+77
-23
lines changed

1 file changed

+77
-23
lines changed

common/src/tap/checks/value_check.rs

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ use sqlx::{postgres::PgListener, PgPool};
88
use std::{
99
cmp::min,
1010
collections::HashMap,
11-
sync::{Arc, Mutex},
11+
str::FromStr,
12+
sync::{Arc, Mutex, RwLock},
1213
time::Duration,
1314
};
1415
use thegraph_core::DeploymentId;
@@ -22,7 +23,7 @@ use tap_core::receipt::{
2223
};
2324

2425
pub struct MinimumValue {
25-
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
26+
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
2627
watcher_cancel_token: tokio_util::sync::CancellationToken,
2728
}
2829

@@ -36,7 +37,9 @@ impl Drop for MinimumValue {
3637

3738
impl MinimumValue {
3839
pub async fn new(pgpool: PgPool) -> Self {
39-
let cost_model_cache = Arc::new(Mutex::new(HashMap::<DeploymentId, CostModelCache>::new()));
40+
let cost_model_cache = Arc::new(RwLock::new(
41+
HashMap::<DeploymentId, Mutex<CostModelCache>>::new(),
42+
));
4043

4144
let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap();
4245
pglistener.listen("cost_models_update_notify").await.expect(
@@ -58,10 +61,23 @@ impl MinimumValue {
5861
}
5962
}
6063

64+
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
65+
// get agora model for the allocation_id
66+
let cache = self.cost_model_cache.read().unwrap();
67+
// on average, we'll have zero or one model
68+
let models = cache.get(&agora_query.deployment_id);
69+
70+
let expected_value = models
71+
.map(|cache| cache.lock().unwrap().cost(agora_query))
72+
.unwrap_or_default();
73+
74+
Ok(expected_value)
75+
}
76+
6177
async fn cost_models_watcher(
62-
_pgpool: PgPool,
78+
pgpool: PgPool,
6379
mut pglistener: PgListener,
64-
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
80+
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
6581
cancel_token: tokio_util::sync::CancellationToken,
6682
) {
6783
loop {
@@ -88,12 +104,12 @@ impl MinimumValue {
88104
"INSERT" => {
89105
let cost_model_source: CostModelSource = cost_model_notification.into();
90106
let mut cost_model_cache = cost_model_cache
91-
.lock()
107+
.write()
92108
.unwrap();
93109

94110
match cost_model_cache.get_mut(&deployment_id) {
95111
Some(cache) => {
96-
let _ = cache.insert_model(cost_model_source);
112+
let _ = cache.lock().unwrap().insert_model(cost_model_source);
97113
},
98114
None => {
99115
if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| {
@@ -102,14 +118,14 @@ impl MinimumValue {
102118
deployment_id, err
103119
)
104120
}) {
105-
cost_model_cache.insert(deployment_id, cache);
121+
cost_model_cache.insert(deployment_id, Mutex::new(cache));
106122
}
107123
},
108124
}
109125
}
110126
"DELETE" => {
111127
cost_model_cache
112-
.lock()
128+
.write()
113129
.unwrap()
114130
.remove(&cost_model_notification.deployment);
115131
}
@@ -122,29 +138,47 @@ impl MinimumValue {
122138
cost_model_notification.tg_op
123139
);
124140

125-
// Self::sender_denylist_reload(pgpool.clone(), denylist.clone())
126-
// .await
127-
// .expect("should be able to reload cost models")
141+
Self::value_check_reload(&pgpool, cost_model_cache.clone())
142+
.await
143+
.expect("should be able to reload cost models")
128144
}
129145
}
130146
}
131147
}
132148
}
133149
}
134-
}
135150

136-
impl MinimumValue {
137-
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
138-
// get agora model for the allocation_id
139-
let mut cache = self.cost_model_cache.lock().unwrap();
140-
// on average, we'll have zero or one model
141-
let models = cache.get_mut(&agora_query.deployment_id);
151+
async fn value_check_reload(
152+
pgpool: &PgPool,
153+
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
154+
) -> anyhow::Result<()> {
155+
let models = sqlx::query!(
156+
r#"
157+
SELECT deployment, model, variables
158+
FROM "CostModels"
159+
WHERE deployment != 'global'
160+
ORDER BY deployment ASC
161+
"#
162+
)
163+
.fetch_all(pgpool)
164+
.await?;
165+
let models = models
166+
.into_iter()
167+
.map(|record| {
168+
let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?;
169+
let model = CostModelCache::new(CostModelSource {
170+
deployment_id,
171+
model: record.model.unwrap(),
172+
variables: record.variables.unwrap().to_string(),
173+
})?;
174+
175+
Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model)))
176+
})
177+
.collect::<Result<HashMap<_, _>, _>>()?;
142178

143-
let expected_value = models
144-
.map(|cache| cache.cost(agora_query))
145-
.unwrap_or_default();
179+
*(cost_model_cache.write().unwrap()) = models;
146180

147-
Ok(expected_value)
181+
Ok(())
148182
}
149183
}
150184

@@ -279,3 +313,23 @@ impl CostModelCache {
279313
.unwrap_or_default()
280314
}
281315
}
316+
317+
#[cfg(test)]
318+
mod tests {
319+
use sqlx::PgPool;
320+
321+
#[sqlx::test(migrations = "../migrations")]
322+
async fn initialize_check(pg_pool: PgPool) {}
323+
324+
#[sqlx::test(migrations = "../migrations")]
325+
async fn should_initialize_check_with_caches(pg_pool: PgPool) {}
326+
327+
#[sqlx::test(migrations = "../migrations")]
328+
async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {}
329+
330+
#[sqlx::test(migrations = "../migrations")]
331+
async fn should_expire_old_model(pg_pool: PgPool) {}
332+
333+
#[sqlx::test(migrations = "../migrations")]
334+
async fn should_verify_global_model(pg_pool: PgPool) {}
335+
}

0 commit comments

Comments
 (0)