Skip to content

Commit edd67f8

Browse files
committed
refactor: update cost model to use history
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent 0ae3eb8 commit edd67f8

File tree

4 files changed

+74
-48
lines changed

4 files changed

+74
-48
lines changed

common/src/tap/checks/value_check.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use std::{
1212
time::Duration,
1313
};
1414
use thegraph_core::DeploymentId;
15-
use tokio::task::JoinHandle;
1615
use tracing::error;
1716
use ttl_cache::TtlCache;
1817

@@ -24,7 +23,15 @@ use tap_core::receipt::{
2423

2524
pub struct MinimumValue {
2625
cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
27-
model_handle: JoinHandle<()>,
26+
watcher_cancel_token: tokio_util::sync::CancellationToken,
27+
}
28+
29+
impl Drop for MinimumValue {
30+
fn drop(&mut self) {
31+
// Clean shutdown for the sender_denylist_watcher
32+
// Though since it's not a critical task, we don't wait for it to finish (join).
33+
self.watcher_cancel_token.cancel();
34+
}
2835
}
2936

3037
impl MinimumValue {
@@ -37,19 +44,17 @@ impl MinimumValue {
3744
'cost_models_update_notify'",
3845
);
3946

40-
// TODO start watcher
41-
let cancel_token = tokio_util::sync::CancellationToken::new();
42-
43-
let model_handle = tokio::spawn(Self::cost_models_watcher(
47+
let watcher_cancel_token = tokio_util::sync::CancellationToken::new();
48+
tokio::spawn(Self::cost_models_watcher(
4449
pgpool.clone(),
4550
pglistener,
4651
cost_model_cache.clone(),
47-
cancel_token.clone(),
52+
watcher_cancel_token.clone(),
4853
));
4954

5055
Self {
5156
cost_model_cache,
52-
model_handle,
57+
watcher_cancel_token,
5358
}
5459
}
5560

@@ -128,12 +133,6 @@ impl MinimumValue {
128133
}
129134
}
130135

131-
impl Drop for MinimumValue {
132-
fn drop(&mut self) {
133-
self.model_handle.abort();
134-
}
135-
}
136-
137136
impl MinimumValue {
138137
fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
139138
// get agora model for the allocation_id
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
-- Add down migration script here
2-
DROP TABLE "CostModels";
2+
DROP TRIGGER IF EXISTS cost_models_update ON "CostModelsHistory" CASCADE;
3+
4+
DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE;
5+
6+
DROP VIEW "CostModels";
7+
8+
DROP TABLE "CostModelsHistory";
Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1-
CREATE TABLE IF NOT EXISTS "CostModels"
1+
CREATE TABLE IF NOT EXISTS "CostModelsHistory"
22
(
3-
id INT,
3+
id SERIAL PRIMARY KEY,
44
deployment VARCHAR NOT NULL,
55
model TEXT,
66
variables JSONB,
7-
PRIMARY KEY( deployment )
7+
"createdAt" TIMESTAMP WITH TIME ZONE,
8+
"updatedAt" TIMESTAMP WITH TIME ZONE
89
);
10+
11+
CREATE VIEW "CostModels" AS SELECT id,
12+
deployment,
13+
model,
14+
variables,
15+
"createdAt",
16+
"updatedAt"
17+
FROM "CostModelsHistory" t1
18+
JOIN
19+
(
20+
SELECT MAX(id)
21+
FROM "CostModelsHistory"
22+
GROUP BY deployment
23+
) t2
24+
ON t1.id = t2.MAX;
25+
26+
CREATE FUNCTION cost_models_update_notify()
27+
RETURNS trigger AS
28+
$$
29+
BEGIN
30+
IF TG_OP = 'DELETE' THEN
31+
PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment));
32+
RETURN OLD;
33+
ELSIF TG_OP = 'INSERT' THEN
34+
PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
35+
RETURN NEW;
36+
ELSE
37+
PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "%s", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
38+
RETURN NEW;
39+
END IF;
40+
END;
41+
$$ LANGUAGE 'plpgsql';
42+
43+
CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE
44+
ON "CostModelsHistory"
45+
FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();

service/src/database.rs

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub async fn connect(url: &str) -> PgPool {
2626
/// These can have "global" as the deployment ID.
2727
#[derive(Debug, Clone)]
2828
struct DbCostModel {
29-
pub deployment: String,
29+
pub deployment: Option<String>,
3030
pub model: Option<String>,
3131
pub variables: Option<Value>,
3232
}
@@ -46,7 +46,12 @@ impl TryFrom<DbCostModel> for CostModel {
4646

4747
fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> {
4848
Ok(Self {
49-
deployment: DeploymentId::from_str(&db_model.deployment)?,
49+
deployment: DeploymentId::from_str(&db_model.deployment.ok_or(
50+
ParseDeploymentIdError::InvalidIpfsHashLength {
51+
value: String::new(),
52+
length: 0,
53+
},
54+
)?)?,
5055
model: db_model.model,
5156
variables: db_model.variables,
5257
})
@@ -57,7 +62,7 @@ impl From<CostModel> for DbCostModel {
5762
fn from(model: CostModel) -> Self {
5863
let deployment = model.deployment;
5964
DbCostModel {
60-
deployment: format!("{deployment:#x}"),
65+
deployment: Some(format!("{deployment:#x}")),
6166
model: model.model,
6267
variables: model.variables,
6368
}
@@ -210,28 +215,11 @@ mod test {
210215

211216
use super::*;
212217

213-
async fn setup_cost_models_table(pool: &PgPool) {
214-
sqlx::query!(
215-
r#"
216-
CREATE TABLE "CostModels"(
217-
id INT,
218-
deployment VARCHAR NOT NULL,
219-
model TEXT,
220-
variables JSONB,
221-
PRIMARY KEY( deployment )
222-
);
223-
"#,
224-
)
225-
.execute(pool)
226-
.await
227-
.expect("Create test instance in db");
228-
}
229-
230218
async fn add_cost_models(pool: &PgPool, models: Vec<DbCostModel>) {
231219
for model in models {
232220
sqlx::query!(
233221
r#"
234-
INSERT INTO "CostModels" (deployment, model)
222+
INSERT INTO "CostModelsHistory" (deployment, model)
235223
VALUES ($1, $2);
236224
"#,
237225
model.deployment,
@@ -249,7 +237,7 @@ mod test {
249237

250238
fn global_cost_model() -> DbCostModel {
251239
DbCostModel {
252-
deployment: "global".to_string(),
240+
deployment: Some("global".to_string()),
253241
model: Some("default => 0.00001;".to_string()),
254242
variables: None,
255243
}
@@ -281,15 +269,14 @@ mod test {
281269
]
282270
}
283271

284-
#[sqlx::test]
272+
#[sqlx::test(migrations = "../migrations")]
285273
async fn success_cost_models(pool: PgPool) {
286274
let test_models = test_data();
287275
let test_deployments = test_models
288276
.iter()
289277
.map(|model| model.deployment)
290278
.collect::<HashSet<_>>();
291279

292-
setup_cost_models_table(&pool).await;
293280
add_cost_models(&pool, to_db_models(test_models.clone())).await;
294281

295282
// First test: query without deployment filter
@@ -344,7 +331,7 @@ mod test {
344331
}
345332
}
346333

347-
#[sqlx::test]
334+
#[sqlx::test(migrations = "../migrations")]
348335
async fn global_fallback_cost_models(pool: PgPool) {
349336
let test_models = test_data();
350337
let test_deployments = test_models
@@ -353,7 +340,6 @@ mod test {
353340
.collect::<HashSet<_>>();
354341
let global_model = global_cost_model();
355342

356-
setup_cost_models_table(&pool).await;
357343
add_cost_models(&pool, to_db_models(test_models.clone())).await;
358344
add_cost_models(&pool, vec![global_model.clone()]).await;
359345

@@ -436,9 +422,8 @@ mod test {
436422
assert_eq!(missing_model.model, global_model.model);
437423
}
438424

439-
#[sqlx::test]
425+
#[sqlx::test(migrations = "../migrations")]
440426
async fn success_cost_model(pool: PgPool) {
441-
setup_cost_models_table(&pool).await;
442427
add_cost_models(&pool, to_db_models(test_data())).await;
443428

444429
let deployment_id_from_bytes = DeploymentId::from_str(
@@ -459,12 +444,11 @@ mod test {
459444
assert_eq!(model.model, Some("default => 0.00025;".to_string()));
460445
}
461446

462-
#[sqlx::test]
447+
#[sqlx::test(migrations = "../migrations")]
463448
async fn global_fallback_cost_model(pool: PgPool) {
464449
let test_models = test_data();
465450
let global_model = global_cost_model();
466451

467-
setup_cost_models_table(&pool).await;
468452
add_cost_models(&pool, to_db_models(test_models.clone())).await;
469453
add_cost_models(&pool, vec![global_model.clone()]).await;
470454

0 commit comments

Comments
 (0)