Skip to content

Commit 52d621f

Browse files
committed
feat: add ttl cache for older cost models
1 parent c7d9436 commit 52d621f

File tree

3 files changed

+128
-56
lines changed

3 files changed

+128
-56
lines changed

Cargo.lock

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ regex = "1.7.1"
3030
axum-extra = { version = "0.9.3", features = [
3131
"typed-header",
3232
], default-features = false }
33+
ttl_cache = "0.5.1"
3334
autometrics = { version = "1.0.1", features = ["prometheus-exporter"] }
3435
tower_governor = "0.3.2"
3536
tower-http = { version = "0.5.2", features = [

common/src/tap/checks/value_check.rs

Lines changed: 110 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ use std::{
99
cmp::min,
1010
collections::HashMap,
1111
sync::{Arc, Mutex},
12+
time::Duration,
1213
};
1314
use thegraph_core::DeploymentId;
14-
use tokio::{select, sync::mpsc::Receiver, task::JoinHandle};
15+
use tokio::{sync::mpsc::Receiver, task::JoinHandle};
16+
use ttl_cache::TtlCache;
1517

1618
use tap_core::{
1719
receipt::{
@@ -23,67 +25,77 @@ use tap_core::{
2325
};
2426

2527
pub struct MinimumValue {
26-
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModel>>>,
28+
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
2729
query_ids: Arc<Mutex<HashMap<SignatureBytes, AgoraQuery>>>,
28-
handle: JoinHandle<()>,
30+
model_handle: JoinHandle<()>,
31+
query_handle: JoinHandle<()>,
2932
}
3033

3134
impl MinimumValue {
3235
pub fn new(
3336
mut rx_cost_model: Receiver<CostModelSource>,
3437
mut rx_query: Receiver<AgoraQuery>,
3538
) -> Self {
36-
let cost_model_cache = Arc::new(Mutex::new(HashMap::new()));
39+
let cost_model_cache = Arc::new(Mutex::new(HashMap::<DeploymentId, CostModelCache>::new()));
3740
let query_ids = Arc::new(Mutex::new(HashMap::new()));
3841
let cache = cost_model_cache.clone();
3942
let query_ids_clone = query_ids.clone();
40-
let handle = tokio::spawn(async move {
43+
let model_handle = tokio::spawn(async move {
4144
loop {
42-
select! {
43-
model = rx_cost_model.recv() => {
44-
match model {
45-
Some(value) => {
46-
let deployment_id = value.deployment_id;
47-
48-
match compile_cost_model(value) {
49-
Ok(value) => {
50-
// todo keep track of the last X models
51-
cache.lock().unwrap().insert(deployment_id, value);
52-
}
53-
Err(err) => {
54-
tracing::error!(
55-
"Error while compiling cost model for deployment id {}. Error: {}",
56-
deployment_id, err
57-
)
58-
}
45+
let model = rx_cost_model.recv().await;
46+
match model {
47+
Some(value) => {
48+
let deployment_id = value.deployment_id;
49+
50+
if let Some(query) = cache.lock().unwrap().get_mut(&deployment_id) {
51+
let _ = query.insert_model(value);
52+
} else {
53+
match CostModelCache::new(value) {
54+
Ok(value) => {
55+
cache.lock().unwrap().insert(deployment_id, value);
56+
}
57+
Err(err) => {
58+
tracing::error!(
59+
"Error while compiling cost model for deployment id {}. Error: {}",
60+
deployment_id, err
61+
)
5962
}
6063
}
61-
None => continue,
6264
}
6365
}
64-
query = rx_query.recv() => {
65-
match query {
66-
Some(query) => {
67-
query_ids_clone.lock().unwrap().insert(query.signature.get_signature_bytes(), query);
68-
},
69-
None => continue,
70-
}
66+
None => continue,
67+
}
68+
}
69+
});
70+
71+
let query_handle = tokio::spawn(async move {
72+
loop {
73+
let query = rx_query.recv().await;
74+
match query {
75+
Some(query) => {
76+
query_ids_clone
77+
.lock()
78+
.unwrap()
79+
.insert(query.signature.get_signature_bytes(), query);
7180
}
81+
None => continue,
7282
}
7383
}
7484
});
7585

7686
Self {
7787
cost_model_cache,
78-
handle,
88+
model_handle,
7989
query_ids,
90+
query_handle,
8091
}
8192
}
8293
}
8394

8495
impl Drop for MinimumValue {
8596
fn drop(&mut self) {
86-
self.handle.abort();
97+
self.model_handle.abort();
98+
self.query_handle.abort();
8799
}
88100
}
89101

@@ -103,32 +115,16 @@ impl Check for MinimumValue {
103115
.map_err(CheckError::Failed)?;
104116

105117
// get agora model for the allocation_id
106-
let cache = self.cost_model_cache.lock().unwrap();
118+
let mut cache = self.cost_model_cache.lock().unwrap();
107119

108120
// on average, we'll have zero or one model
109-
let models = cache
110-
.get(&agora_query.deployment_id)
111-
.map(|model| vec![model])
112-
.unwrap_or_default();
121+
let models = cache.get_mut(&agora_query.deployment_id);
113122

114123
// get value
115124
let value = receipt.signed_receipt().message.value;
116125

117126
let expected_value = models
118-
.into_iter()
119-
.fold(None, |acc, model| {
120-
let value = model
121-
.cost(&agora_query.query, &agora_query.variables)
122-
.ok()
123-
.map(|fee| fee.to_u128().unwrap_or_default())
124-
.unwrap_or_default();
125-
if let Some(acc) = acc {
126-
// return the minimum value of the cache list
127-
Some(min(acc, value))
128-
} else {
129-
Some(value)
130-
}
131-
})
127+
.map(|cache| cache.cost(&agora_query))
132128
.unwrap_or_default();
133129

134130
let should_accept = value >= expected_value;
@@ -151,11 +147,11 @@ impl Check for MinimumValue {
151147
}
152148
}
153149

154-
fn compile_cost_model(src: CostModelSource) -> Result<CostModel, String> {
150+
fn compile_cost_model(src: CostModelSource) -> anyhow::Result<CostModel> {
155151
if src.model.len() > (1 << 16) {
156-
return Err("CostModelTooLarge".into());
152+
return Err(anyhow!("CostModelTooLarge"));
157153
}
158-
let model = CostModel::compile(&src.model, &src.variables).map_err(|err| err.to_string())?;
154+
let model = CostModel::compile(&src.model, &src.variables)?;
159155
Ok(model)
160156
}
161157

@@ -166,9 +162,68 @@ pub struct AgoraQuery {
166162
variables: String,
167163
}
168164

169-
#[derive(Eq, Hash, PartialEq)]
165+
#[derive(Clone, Eq, Hash, PartialEq)]
170166
pub struct CostModelSource {
171167
deployment_id: DeploymentId,
172168
model: String,
173169
variables: String,
174170
}
171+
172+
pub struct CostModelCache {
173+
models: TtlCache<CostModelSource, CostModel>,
174+
latest_model: CostModel,
175+
latest_source: CostModelSource,
176+
}
177+
178+
impl CostModelCache {
179+
pub fn new(source: CostModelSource) -> anyhow::Result<Self> {
180+
let model = compile_cost_model(source.clone())?;
181+
Ok(Self {
182+
latest_model: model,
183+
latest_source: source,
184+
// arbitrary number of models copy
185+
models: TtlCache::new(10),
186+
})
187+
}
188+
189+
fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> {
190+
if source != self.latest_source {
191+
let model = compile_cost_model(source.clone())?;
192+
// update latest and insert into ttl the old model
193+
let old_model = std::mem::replace(&mut self.latest_model, model);
194+
self.latest_source = source.clone();
195+
196+
self.models
197+
// arbitrary cache duration
198+
.insert(source, old_model, Duration::from_secs(60));
199+
}
200+
Ok(())
201+
}
202+
203+
fn get_models(&mut self) -> Vec<&CostModel> {
204+
let mut values: Vec<&CostModel> = self.models.iter().map(|(_, v)| v).collect();
205+
values.push(&self.latest_model);
206+
values
207+
}
208+
209+
fn cost(&mut self, query: &AgoraQuery) -> u128 {
210+
let models = self.get_models();
211+
212+
models
213+
.into_iter()
214+
.fold(None, |acc, model| {
215+
let value = model
216+
.cost(&query.query, &query.variables)
217+
.ok()
218+
.map(|fee| fee.to_u128().unwrap_or_default())
219+
.unwrap_or_default();
220+
if let Some(acc) = acc {
221+
// return the minimum value of the cache list
222+
Some(min(acc, value))
223+
} else {
224+
Some(value)
225+
}
226+
})
227+
.unwrap_or_default()
228+
}
229+
}

0 commit comments

Comments
 (0)