Skip to content

Commit 4e0c2ad

Browse files
committed
refactor: remove ttl, use tasks for expire
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent bec3399 commit 4e0c2ad

File tree

4 files changed

+104
-98
lines changed

4 files changed

+104
-98
lines changed

Cargo.lock

Lines changed: 0 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ regex = "1.7.1"
3030
axum-extra = { version = "0.9.3", features = [
3131
"typed-header",
3232
], default-features = false }
33-
ttl_cache = "0.5.1"
3433
autometrics = { version = "1.0.1", features = ["prometheus-exporter"] }
3534
tower_governor = "0.3.2"
3635
tower-http = { version = "0.5.2", features = [

common/src/tap.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use tracing::error;
2424
mod checks;
2525
mod receipt_store;
2626

27-
pub use checks::value_check::{AgoraQuery, CostModelSource};
27+
pub use checks::value_check::AgoraQuery;
2828

2929
#[derive(Clone)]
3030
pub struct IndexerTapContext {

common/src/tap/checks/value_check.rs

Lines changed: 103 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,28 @@ use cost_model::CostModel;
77
use sqlx::{postgres::PgListener, PgPool};
88
use std::{
99
cmp::min,
10-
collections::HashMap,
10+
collections::{hash_map::Entry, HashMap, VecDeque},
1111
str::FromStr,
1212
sync::{Arc, Mutex, RwLock},
1313
time::Duration,
1414
};
15-
use thegraph_core::DeploymentId;
15+
use thegraph_core::{DeploymentId, ParseDeploymentIdError};
16+
use tokio::{sync::mpsc::channel, task::JoinHandle, time::sleep};
1617
use tracing::error;
17-
use ttl_cache::TtlCache;
1818

1919
use tap_core::receipt::{
2020
checks::{Check, CheckError, CheckResult},
2121
state::Checking,
2222
Context, ReceiptWithState,
2323
};
2424

25+
// we only accept receipts worth at least 1 wei of GRT
26+
const MINIMAL_VALUE: u128 = 1;
27+
28+
type CostModelMap = Arc<RwLock<HashMap<DeploymentId, RwLock<CostModelCache>>>>;
29+
2530
pub struct MinimumValue {
26-
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
31+
cost_model_cache: CostModelMap,
2732
watcher_cancel_token: tokio_util::sync::CancellationToken,
2833
}
2934

@@ -37,9 +42,7 @@ impl Drop for MinimumValue {
3742

3843
impl MinimumValue {
3944
pub async fn new(pgpool: PgPool) -> Self {
40-
let cost_model_cache = Arc::new(RwLock::new(
41-
HashMap::<DeploymentId, Mutex<CostModelCache>>::new(),
42-
));
45+
let cost_model_cache: CostModelMap = Default::default();
4346

4447
let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap();
4548
pglistener.listen("cost_models_update_notify").await.expect(
@@ -64,28 +67,48 @@ impl MinimumValue {
6467
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
6568
// get the agora cost models cached for the query's deployment_id
6669
let cache = self.cost_model_cache.read().unwrap();
67-
// on average, we'll have zero or one model
6870
let models = cache.get(&agora_query.deployment_id);
6971

7072
let expected_value = models
71-
.map(|cache| cache.lock().unwrap().cost(agora_query))
72-
.unwrap_or_default();
73+
.map(|cache| {
74+
let cache = cache.read().unwrap();
75+
cache.cost(agora_query)
76+
})
77+
.unwrap_or(MINIMAL_VALUE);
7378

7479
Ok(expected_value)
7580
}
7681

7782
async fn cost_models_watcher(
7883
pgpool: PgPool,
7984
mut pglistener: PgListener,
80-
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
85+
cost_model_cache: CostModelMap,
8186
cancel_token: tokio_util::sync::CancellationToken,
8287
) {
88+
let handles: Arc<Mutex<HashMap<DeploymentId, VecDeque<JoinHandle<()>>>>> =
89+
Default::default();
90+
let (tx, mut rx) = channel::<DeploymentId>(64);
91+
8392
loop {
8493
tokio::select! {
8594
_ = cancel_token.cancelled() => {
8695
break;
8796
}
97+
Some(deployment_id) = rx.recv() => {
98+
let mut cost_model_write = cost_model_cache.write().unwrap();
99+
if let Some(cache) = cost_model_write.get_mut(&deployment_id) {
100+
cache.get_mut().unwrap().expire();
101+
}
88102

103+
if let Entry::Occupied(mut entry) = handles.lock().unwrap().entry(deployment_id) {
104+
let vec = entry.get_mut();
105+
vec.pop_front();
106+
if vec.is_empty() {
107+
entry.remove();
108+
}
109+
}
110+
111+
}
89112
pg_notification = pglistener.recv() => {
90113
let pg_notification = pg_notification.expect(
91114
"should be able to receive Postgres Notify events on the channel \
@@ -103,31 +126,38 @@ impl MinimumValue {
103126
match cost_model_notification.tg_op.as_str() {
104127
"INSERT" => {
105128
let cost_model_source: CostModelSource = cost_model_notification.into();
106-
let mut cost_model_cache = cost_model_cache
107-
.write()
108-
.unwrap();
109-
110-
match cost_model_cache.get_mut(&deployment_id) {
111-
Some(cache) => {
112-
let _ = cache.lock().unwrap().insert_model(cost_model_source);
113-
},
114-
None => {
115-
if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| {
116-
tracing::error!(
117-
"Error while compiling cost model for deployment id {}. Error: {}",
118-
deployment_id, err
119-
)
120-
}) {
121-
cost_model_cache.insert(deployment_id, Mutex::new(cache));
122-
}
123-
},
129+
{
130+
let mut cost_model_write = cost_model_cache
131+
.write()
132+
.unwrap();
133+
let cache = cost_model_write.entry(deployment_id).or_default();
134+
let _ = cache.get_mut().unwrap().insert_model(cost_model_source);
124135
}
136+
let _tx = tx.clone();
137+
138+
// expire after 60 seconds
139+
handles.lock()
140+
.unwrap()
141+
.entry(deployment_id)
142+
.or_default()
143+
.push_back(tokio::spawn(async move {
144+
// after 1 minute, notify the watcher so it expires the oldest cached model
145+
sleep(Duration::from_secs(60)).await;
146+
let _ = _tx.send(deployment_id).await;
147+
}));
125148
}
126149
"DELETE" => {
127-
cost_model_cache
128-
.write()
129-
.unwrap()
130-
.remove(&cost_model_notification.deployment);
150+
if let Entry::Occupied(mut entry) = cost_model_cache
151+
.write().unwrap().entry(cost_model_notification.deployment) {
152+
let should_remove = {
153+
let mut cost_model = entry.get_mut().write().unwrap();
154+
cost_model.expire();
155+
cost_model.is_empty()
156+
};
157+
if should_remove {
158+
entry.remove();
159+
}
160+
}
131161
}
132162
// UPDATE and TRUNCATE are not expected to happen. Reload the entire cost
133163
// model cache.
@@ -138,6 +168,17 @@ impl MinimumValue {
138168
cost_model_notification.tg_op
139169
);
140170

171+
{
172+
// abort and clear all pending expiration tasks
173+
let mut handles = handles.lock().unwrap();
174+
for maps in handles.values() {
175+
for handle in maps {
176+
handle.abort();
177+
}
178+
}
179+
handles.clear();
180+
}
181+
141182
Self::value_check_reload(&pgpool, cost_model_cache.clone())
142183
.await
143184
.expect("should be able to reload cost models")
@@ -150,7 +191,7 @@ impl MinimumValue {
150191

151192
async fn value_check_reload(
152193
pgpool: &PgPool,
153-
cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
194+
cost_model_cache: CostModelMap,
154195
) -> anyhow::Result<()> {
155196
let models = sqlx::query!(
156197
r#"
@@ -166,13 +207,14 @@ impl MinimumValue {
166207
.into_iter()
167208
.map(|record| {
168209
let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?;
169-
let model = CostModelCache::new(CostModelSource {
210+
let mut model = CostModelCache::default();
211+
let _ = model.insert_model(CostModelSource {
170212
deployment_id,
171213
model: record.model.unwrap(),
172-
variables: record.variables.unwrap().to_string(),
173-
})?;
214+
variables: record.variables.unwrap_or_default(),
215+
});
174216

175-
Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model)))
217+
Ok::<_, ParseDeploymentIdError>((deployment_id, RwLock::new(model)))
176218
})
177219
.collect::<Result<HashMap<_, _>, _>>()?;
178220

@@ -220,7 +262,7 @@ fn compile_cost_model(src: CostModelSource) -> anyhow::Result<CostModel> {
220262
if src.model.len() > (1 << 16) {
221263
return Err(anyhow!("CostModelTooLarge"));
222264
}
223-
let model = CostModel::compile(&src.model, &src.variables)?;
265+
let model = CostModel::compile(&src.model, &src.variables.to_string())?;
224266
Ok(model)
225267
}
226268

@@ -231,18 +273,18 @@ pub struct AgoraQuery {
231273
}
232274

233275
#[derive(Clone, Eq, Hash, PartialEq)]
234-
pub struct CostModelSource {
276+
struct CostModelSource {
235277
pub deployment_id: DeploymentId,
236278
pub model: String,
237-
pub variables: String,
279+
pub variables: serde_json::Value,
238280
}
239281

240282
#[derive(serde::Deserialize)]
241283
struct CostModelNotification {
242284
tg_op: String,
243285
deployment: DeploymentId,
244286
model: String,
245-
variables: String,
287+
variables: serde_json::Value,
246288
}
247289

248290
impl From<CostModelNotification> for CostModelSource {
@@ -255,48 +297,29 @@ impl From<CostModelNotification> for CostModelSource {
255297
}
256298
}
257299

258-
pub struct CostModelCache {
259-
models: TtlCache<CostModelSource, CostModel>,
260-
latest_model: CostModel,
261-
latest_source: CostModelSource,
300+
#[derive(Default)]
301+
struct CostModelCache {
302+
models: VecDeque<CostModel>,
262303
}
263304

264305
impl CostModelCache {
265-
pub fn new(source: CostModelSource) -> anyhow::Result<Self> {
266-
let model = compile_cost_model(source.clone())?;
267-
Ok(Self {
268-
latest_model: model,
269-
latest_source: source,
270-
// arbitrary number of models copy
271-
models: TtlCache::new(10),
272-
})
273-
}
274-
275306
fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> {
276-
if source != self.latest_source {
277-
let model = compile_cost_model(source.clone())?;
278-
// update latest and insert into ttl the old model
279-
let old_model = std::mem::replace(&mut self.latest_model, model);
280-
self.latest_source = source.clone();
281-
282-
self.models
283-
// arbitrary cache duration
284-
.insert(source, old_model, Duration::from_secs(60));
285-
}
307+
let model = compile_cost_model(source.clone())?;
308+
self.models.push_back(model);
286309
Ok(())
287310
}
288311

289-
fn get_models(&mut self) -> Vec<&CostModel> {
290-
let mut values: Vec<&CostModel> = self.models.iter().map(|(_, v)| v).collect();
291-
values.push(&self.latest_model);
292-
values
312+
fn expire(&mut self) {
313+
self.models.pop_front();
293314
}
294315

295-
fn cost(&mut self, query: &AgoraQuery) -> u128 {
296-
let models = self.get_models();
316+
fn is_empty(&self) -> bool {
317+
self.models.is_empty()
318+
}
297319

298-
models
299-
.into_iter()
320+
fn cost(&self, query: &AgoraQuery) -> u128 {
321+
self.models
322+
.iter()
300323
.fold(None, |acc, model| {
301324
let value = model
302325
.cost(&query.query, &query.variables)
@@ -310,7 +333,7 @@ impl CostModelCache {
310333
Some(value)
311334
}
312335
})
313-
.unwrap_or_default()
336+
.unwrap_or(MINIMAL_VALUE)
314337
}
315338
}
316339

@@ -319,17 +342,17 @@ mod tests {
319342
use sqlx::PgPool;
320343

321344
#[sqlx::test(migrations = "../migrations")]
322-
async fn initialize_check(pg_pool: PgPool) {}
345+
async fn initialize_check(_pg_pool: PgPool) {}
323346

324347
#[sqlx::test(migrations = "../migrations")]
325-
async fn should_initialize_check_with_caches(pg_pool: PgPool) {}
348+
async fn should_initialize_check_with_caches(_pg_pool: PgPool) {}
326349

327350
#[sqlx::test(migrations = "../migrations")]
328-
async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {}
351+
async fn should_add_model_to_cache_on_insert(_pg_pool: PgPool) {}
329352

330353
#[sqlx::test(migrations = "../migrations")]
331-
async fn should_expire_old_model(pg_pool: PgPool) {}
354+
async fn should_expire_old_model(_pg_pool: PgPool) {}
332355

333356
#[sqlx::test(migrations = "../migrations")]
334-
async fn should_verify_global_model(pg_pool: PgPool) {}
357+
async fn should_verify_global_model(_pg_pool: PgPool) {}
335358
}

0 commit comments

Comments
 (0)