From 1bbbccb6152ae7fdfbdcd1ebb8ced547afc99941 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 19 Apr 2024 15:10:28 -0300 Subject: [PATCH 01/21] feat: add value check Signed-off-by: Gustavo Inacio --- Cargo.lock | 157 +++++++++++++++++++++--- common/Cargo.toml | 1 + common/src/tap/checks.rs | 1 + common/src/tap/checks/value_check.rs | 174 +++++++++++++++++++++++++++ 4 files changed, 316 insertions(+), 17 deletions(-) create mode 100644 common/src/tap/checks/value_check.rs diff --git a/Cargo.lock b/Cargo.lock index 82d5cb970..3575e4d36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,7 +353,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26154390b1d205a4a7ac7352aa2eb4f81f391399d4e2f546fb81a2f8bb383f62" dependencies = [ "alloy-rlp-derive", - "arrayvec", + "arrayvec 0.7.6", "bytes", ] @@ -712,7 +712,7 @@ dependencies = [ "ark-serialize 0.3.0", "ark-std 0.3.0", "derivative", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "paste", "rustc_version 0.3.3", @@ -732,7 +732,7 @@ dependencies = [ "derivative", "digest 0.10.7", "itertools 0.10.5", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "paste", "rustc_version 0.4.1", @@ -765,7 +765,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "quote", "syn 1.0.109", @@ -777,7 +777,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "proc-macro2", "quote", @@ -802,7 +802,7 @@ checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" dependencies = [ "ark-std 0.4.0", "digest 0.10.7", - "num-bigint", + "num-bigint 0.4.6", ] [[package]] @@ -825,6 +825,12 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "arrayvec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + [[package]] name = "arrayvec" version = "0.7.6" @@ -1263,7 +1269,7 @@ checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" dependencies = [ "autocfg", "libm", - "num-bigint", + "num-bigint 0.4.6", "num-integer", "num-traits", "serde", @@ -1482,7 +1488,7 @@ dependencies = [ "bincode", "build-info-common", "chrono", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "proc-macro-error2", "proc-macro2", @@ -1818,6 +1824,22 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cost-model" +version = "0.1.0" +source = "git+https://github.com/graphprotocol/agora?rev=3ed34ca#3ed34cae6dadded9c8c6ca34356309c98f0f0578" +dependencies = [ + "firestorm", + "fraction", + "graphql 0.2.0", + "itertools 0.12.1", + "lazy_static", + "nom 5.1.3", + "num-bigint 0.2.6", + "num-traits", + "serde_json", +] + [[package]] name = "cpufeatures" version = "0.2.14" @@ -2243,7 +2265,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "auto_impl", "bytes", ] @@ -2347,6 +2369,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "fraction" +version = "0.6.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27f0e7512f6915c9bc38594725e33d7673da9308fea0abf4cc258c281cdbb2a" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -2566,6 +2598,16 @@ dependencies = [ "spinning_top", ] +[[package]] +name = "graphql" +version = "0.2.0" +source = "git+https://github.com/edgeandnode/toolshed?tag=graphql-v0.2.0#2df5ee975656027d4a7fbf591baf6a29dfbe0ee6" +dependencies = [ + "firestorm", + "graphql-parser", + "serde", +] + [[package]] name = "graphql" version = "0.3.0" @@ -3095,6 +3137,7 @@ dependencies = [ "bigdecimal", "bip39", "build-info", + "cost-model", "env_logger", "eventuals", "graphql_client", @@ -3153,7 +3196,7 @@ dependencies = [ "build-info", "build-info-build", "clap", - "graphql", + "graphql 0.3.0", "hex-literal", "indexer-common", "indexer-config", @@ -3292,6 +3335,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -3555,6 +3607,19 @@ dependencies = [ "spin", ] +[[package]] +name = "lexical-core" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe" +dependencies = [ + "arrayvec 0.5.2", + "bitflags 1.3.2", + "cfg-if", + "ryu", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.159" @@ -3799,6 +3864,17 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" +[[package]] +name = "nom" +version = "5.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08959a387a676302eebf4ddbcbc611da04285579f76f88ee0506c63b1a61dd4b" +dependencies = [ + "lexical-core", + "memchr", + "version_check", +] + [[package]] name = "nom" version = "7.1.3" @@ -3831,6 +3907,31 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -3858,6 +3959,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -3884,6 +3995,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ 
-4056,7 +4179,7 @@ version = "3.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "306800abfa29c7f16596b5970a588435e3d5b3149683d00c12b699cc19f895ee" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "bitvec", "byte-slice-cast", "impl-trait-for-tuples", @@ -4915,7 +5038,7 @@ dependencies = [ "ark-ff 0.4.2", "bytes", "fastrlp", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "parity-scale-codec", "primitive-types", @@ -4940,7 +5063,7 @@ version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b082d80e3e3cc52b2ed634388d436fe1f4de6af5786cc2de9ba9737527bdf555" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "borsh", "bytes", "num-traits", @@ -5233,7 +5356,7 @@ dependencies = [ "core-foundation", "core-foundation-sys", "libc", - "num-bigint", + "num-bigint 0.4.6", "security-framework-sys", ] @@ -5514,7 +5637,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "thiserror", "time", @@ -5615,7 +5738,7 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" dependencies = [ - "nom", + "nom 7.1.3", "unicode_categories", ] @@ -5785,7 +5908,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.6", "once_cell", "rand 0.8.5", "rust_decimal", diff --git a/common/Cargo.toml b/common/Cargo.toml index 4cb1b25ba..911fcc048 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -26,6 +26,7 @@ graphql_client.workspace = true serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["fs", "tokio-macros"] } +cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3ed34ca" } regex = "1.7.1" axum-extra = { version = "0.9.3", features = [ "typed-header", diff --git a/common/src/tap/checks.rs b/common/src/tap/checks.rs index 4da34f594..d4e23e9b2 100644 --- a/common/src/tap/checks.rs +++ b/common/src/tap/checks.rs @@ -6,3 +6,4 @@ pub mod deny_list_check; pub mod receipt_max_val_check; pub mod sender_balance_check; pub mod timestamp_check; +pub mod value_check; diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs new file mode 100644 index 000000000..43f6acb04 --- /dev/null +++ b/common/src/tap/checks/value_check.rs @@ -0,0 +1,174 @@ +// Copyright 2023-, GraphOps and Semiotic Labs. 
+// SPDX-License-Identifier: Apache-2.0
+
+use alloy::signers::Signature;
+use anyhow::anyhow;
+use bigdecimal::ToPrimitive;
+use cost_model::CostModel;
+use std::{
+    cmp::min,
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
+use thegraph_core::DeploymentId;
+use tokio::{select, sync::mpsc::Receiver, task::JoinHandle};
+
+use tap_core::{
+    receipt::{
+        checks::{Check, CheckError, CheckResult},
+        state::Checking,
+        ReceiptWithState,
+    },
+    signed_message::{SignatureBytes, SignatureBytesExt},
+};
+
+pub struct MinimumValue {
+    cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModel>>>,
+    query_ids: Arc<Mutex<HashMap<SignatureBytes, AgoraQuery>>>,
+    handle: JoinHandle<()>,
+}
+
+impl MinimumValue {
+    pub fn new(
+        mut rx_cost_model: Receiver<CostModelSource>,
+        mut rx_query: Receiver<AgoraQuery>,
+    ) -> Self {
+        let cost_model_cache = Arc::new(Mutex::new(HashMap::new()));
+        let query_ids = Arc::new(Mutex::new(HashMap::new()));
+        let cache = cost_model_cache.clone();
+        let query_ids_clone = query_ids.clone();
+        let handle = tokio::spawn(async move {
+            loop {
+                select! {
+                    model = rx_cost_model.recv() => {
+                        match model {
+                            Some(value) => {
+                                let deployment_id = value.deployment_id;
+
+                                match compile_cost_model(value) {
+                                    Ok(value) => {
+                                        // todo keep track of the last X models
+                                        cache.lock().unwrap().insert(deployment_id, value);
+                                    }
+                                    Err(err) => {
+                                        tracing::error!(
+                                            "Error while compiling cost model for deployment id {}. Error: {}",
+                                            deployment_id, err
+                                        )
+                                    }
+                                }
+                            }
+                            None => continue,
+                        }
+                    }
+                    query = rx_query.recv() => {
+                        match query {
+                            Some(query) => {
+                                query_ids_clone.lock().unwrap().insert(query.signature.get_signature_bytes(), query);
+                            },
+                            None => continue,
+                        }
+                    }
+                }
+            }
+        });
+
+        Self {
+            cost_model_cache,
+            handle,
+            query_ids,
+        }
+    }
+}
+
+impl Drop for MinimumValue {
+    fn drop(&mut self) {
+        self.handle.abort();
+    }
+}
+
+#[async_trait::async_trait]
+impl Check for MinimumValue {
+    async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
+        // get key
+        let key = &receipt.signed_receipt().signature.get_signature_bytes();
+
+        // get query from key
+        let agora_query = self
+            .query_ids
+            .lock()
+            .unwrap()
+            .remove(key)
+            .ok_or(anyhow!("No query found"))
+            .map_err(CheckError::Failed)?;
+
+        // get agora model for the allocation_id
+        let cache = self.cost_model_cache.lock().unwrap();
+
+        // on average, we'll have zero or one model
+        let models = cache
+            .get(&agora_query.deployment_id)
+            .map(|model| vec![model])
+            .unwrap_or_default();
+
+        // get value
+        let value = receipt.signed_receipt().message.value;
+
+        let expected_value = models
+            .into_iter()
+            .fold(None, |acc, model| {
+                let value = model
+                    .cost(&agora_query.query, &agora_query.variables)
+                    .ok()
+                    .map(|fee| fee.to_u128().unwrap_or_default())
+                    .unwrap_or_default();
+                if let Some(acc) = acc {
+                    // return the minimum value of the cache list
+                    Some(min(acc, value))
+                } else {
+                    Some(value)
+                }
+            })
+            .unwrap_or_default();
+
+        let should_accept = value >= expected_value;
+
+        tracing::trace!(
+            value,
+            expected_value,
+            should_accept,
+            "Evaluating minimum query fee."
+        );
+
+        if should_accept {
+            Ok(())
+        } else {
+            return Err(CheckError::Failed(anyhow!(
+                "Query receipt does not have the minimum value. Expected value: {}. 
Received value: {}.",
+                expected_value, value,
+            )));
+        }
+    }
+}
+
+fn compile_cost_model(src: CostModelSource) -> Result<CostModel, String> {
+    if src.model.len() > (1 << 16) {
+        return Err("CostModelTooLarge".into());
+    }
+    let model = CostModel::compile(&src.model, &src.variables).map_err(|err| err.to_string())?;
+    Ok(model)
+}
+
+pub struct AgoraQuery {
+    signature: Signature,
+    deployment_id: DeploymentId,
+    query: String,
+    variables: String,
+}
+
+#[derive(Eq, Hash, PartialEq)]
+pub struct CostModelSource {
+    deployment_id: DeploymentId,
+    model: String,
+    variables: String,
+}

From 671ff05a9259de0520172a75e009429c99ce0ace Mon Sep 17 00:00:00 2001
From: Gustavo Inacio
Date: Fri, 19 Apr 2024 17:32:25 -0300
Subject: [PATCH 02/21] feat: add ttl cache for older cost models

---
 Cargo.lock                           |  16 +++
 common/Cargo.toml                    |   1 +
 common/src/tap/checks/value_check.rs | 165 ++++++++++++++++++---------
 3 files changed, 127 insertions(+), 55 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 3575e4d36..5a4d2a737 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3159,6 +3159,7 @@ dependencies = [
  "tower-http",
  "tower_governor",
  "tracing",
+ "ttl_cache",
  "wiremock 0.5.22",
 ]
 
@@ -3643,6 +3644,12 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "linked-hash-map"
+version = "0.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
+
 [[package]]
 name = "linkme"
 version = "0.3.28"
@@ -6688,6 +6695,15 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "ttl_cache"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a"
+dependencies = [
+ "linked-hash-map",
+]
+
 [[package]]
 name = "tungstenite"
 version = "0.23.0"
diff --git a/common/Cargo.toml b/common/Cargo.toml
index 911fcc048..4cf3515c1 100644
--- a/common/Cargo.toml
+++ b/common/Cargo.toml
@@ -31,6 +31,7 @@ regex = "1.7.1"
 axum-extra = { version = "0.9.3", features = [
   "typed-header",
 ], default-features = false }
+ttl_cache = "0.5.1"
 autometrics = { version = "1.0.1", features = ["prometheus-exporter"] }
 tower_governor = "0.3.2"
 tower-http = { version = "0.5.2", features = [
diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index 43f6acb04..ff9427d7f 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -9,9 +9,11 @@ use std::{
     cmp::min,
     collections::HashMap,
     sync::{Arc, Mutex},
+    time::Duration,
 };
 use thegraph_core::DeploymentId;
-use tokio::{select, sync::mpsc::Receiver, task::JoinHandle};
+use tokio::{sync::mpsc::Receiver, task::JoinHandle};
+use ttl_cache::TtlCache;
 
 use tap_core::{
     receipt::{
@@ -23,9 +25,10 @@ use tap_core::{
 };
 
 pub struct MinimumValue {
-    cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModel>>>,
+    cost_model_cache: Arc<Mutex<HashMap<DeploymentId, CostModelCache>>>,
     query_ids: Arc<Mutex<HashMap<SignatureBytes, AgoraQuery>>>,
-    handle: JoinHandle<()>,
+    model_handle: JoinHandle<()>,
+    query_handle: JoinHandle<()>,
 }
 
 impl MinimumValue {
@@ -33,57 +36,66 @@ impl MinimumValue {
         mut rx_cost_model: Receiver<CostModelSource>,
         mut rx_query: Receiver<AgoraQuery>,
     ) -> Self {
-        let cost_model_cache = Arc::new(Mutex::new(HashMap::new()));
+        let cost_model_cache = Arc::new(Mutex::new(HashMap::<DeploymentId, CostModelCache>::new()));
         let query_ids = Arc::new(Mutex::new(HashMap::new()));
         let cache = cost_model_cache.clone();
         let query_ids_clone = query_ids.clone();
-        let handle = tokio::spawn(async move {
+        let model_handle = 
tokio::spawn(async move {
             loop {
-                select! {
-                    model = rx_cost_model.recv() => {
-                        match model {
-                            Some(value) => {
-                                let deployment_id = value.deployment_id;
-
-                                match compile_cost_model(value) {
-                                    Ok(value) => {
-                                        // todo keep track of the last X models
-                                        cache.lock().unwrap().insert(deployment_id, value);
-                                    }
-                                    Err(err) => {
-                                        tracing::error!(
-                                            "Error while compiling cost model for deployment id {}. Error: {}",
-                                            deployment_id, err
-                                        )
-                                    }
-                                }
-                            }
-                            None => continue,
-                        }
-                    }
-                    query = rx_query.recv() => {
-                        match query {
-                            Some(query) => {
-                                query_ids_clone.lock().unwrap().insert(query.signature.get_signature_bytes(), query);
-                            },
-                            None => continue,
-                        }
-                    }
-                }
+                let model = rx_cost_model.recv().await;
+                match model {
+                    Some(value) => {
+                        let deployment_id = value.deployment_id;
+
+                        if let Some(query) = cache.lock().unwrap().get_mut(&deployment_id) {
+                            let _ = query.insert_model(value);
+                        } else {
+                            match CostModelCache::new(value) {
+                                Ok(value) => {
+                                    cache.lock().unwrap().insert(deployment_id, value);
+                                }
+                                Err(err) => {
+                                    tracing::error!(
+                                        "Error while compiling cost model for deployment id {}. Error: {}",
+                                        deployment_id, err
+                                    )
+                                }
+                            }
+                        }
+                    }
+                    None => continue,
+                }
             }
         });
+
+        let query_handle = tokio::spawn(async move {
+            loop {
+                let query = rx_query.recv().await;
+                match query {
+                    Some(query) => {
+                        query_ids_clone
+                            .lock()
+                            .unwrap()
+                            .insert(query.signature.get_signature_bytes(), query);
+                    }
+                    None => continue,
+                }
+            }
+        });
 
         Self {
             cost_model_cache,
-            handle,
+            model_handle,
             query_ids,
+            query_handle,
         }
     }
 }
 
 impl Drop for MinimumValue {
     fn drop(&mut self) {
-        self.handle.abort();
+        self.model_handle.abort();
+        self.query_handle.abort();
     }
 }
@@ -103,32 +115,16 @@ impl Check for MinimumValue {
             .map_err(CheckError::Failed)?;
 
         // get agora model for the allocation_id
-        let cache = self.cost_model_cache.lock().unwrap();
+        let mut cache = self.cost_model_cache.lock().unwrap();
 
         // on average, we'll have zero or one model
-        let models = cache
-            .get(&agora_query.deployment_id)
-            .map(|model| vec![model])
-            .unwrap_or_default();
+        let models = cache.get_mut(&agora_query.deployment_id);
 
         // get value
         let value = receipt.signed_receipt().message.value;
 
         let expected_value = models
-            .into_iter()
-            .fold(None, |acc, model| {
-                let value = model
-                    .cost(&agora_query.query, &agora_query.variables)
-                    .ok()
-                    .map(|fee| fee.to_u128().unwrap_or_default())
-                    .unwrap_or_default();
-                if let Some(acc) = acc {
-                    // return the minimum value of the cache list
-                    Some(min(acc, value))
-                } else {
-                    Some(value)
-                }
-            })
+            .map(|cache| cache.cost(&agora_query))
             .unwrap_or_default();
 
         let should_accept = value >= expected_value;
@@ -151,11 +147,11 @@ impl Check for MinimumValue {
     }
 }
 
-fn compile_cost_model(src: CostModelSource) -> Result<CostModel, String> {
+fn compile_cost_model(src: CostModelSource) -> anyhow::Result<CostModel> {
     if src.model.len() > (1 << 16) {
-        return Err("CostModelTooLarge".into());
+        return Err(anyhow!("CostModelTooLarge"));
     }
-    let model = CostModel::compile(&src.model, &src.variables).map_err(|err| err.to_string())?;
+    let model = CostModel::compile(&src.model, &src.variables)?;
     Ok(model)
 }
 
@@ -166,9 +162,68 @@ pub struct AgoraQuery {
     variables: String,
 }
 
-#[derive(Eq, Hash, PartialEq)]
+#[derive(Clone, Eq, Hash, PartialEq)]
 pub struct CostModelSource {
     deployment_id: DeploymentId,
     model: String,
     variables: String,
 }
+
+pub struct CostModelCache {
+    models: TtlCache<CostModelSource, CostModel>,
+    latest_model: CostModel,
+    latest_source: CostModelSource,
+}
+
+impl CostModelCache {
+    pub fn new(source: CostModelSource) -> anyhow::Result<Self> {
+        let model = 
compile_cost_model(source.clone())?; + Ok(Self { + latest_model: model, + latest_source: source, + // arbitrary number of models copy + models: TtlCache::new(10), + }) + } + + fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> { + if source != self.latest_source { + let model = compile_cost_model(source.clone())?; + // update latest and insert into ttl the old model + let old_model = std::mem::replace(&mut self.latest_model, model); + self.latest_source = source.clone(); + + self.models + // arbitrary cache duration + .insert(source, old_model, Duration::from_secs(60)); + } + Ok(()) + } + + fn get_models(&mut self) -> Vec<&CostModel> { + let mut values: Vec<&CostModel> = self.models.iter().map(|(_, v)| v).collect(); + values.push(&self.latest_model); + values + } + + fn cost(&mut self, query: &AgoraQuery) -> u128 { + let models = self.get_models(); + + models + .into_iter() + .fold(None, |acc, model| { + let value = model + .cost(&query.query, &query.variables) + .ok() + .map(|fee| fee.to_u128().unwrap_or_default()) + .unwrap_or_default(); + if let Some(acc) = acc { + // return the minimum value of the cache list + Some(min(acc, value)) + } else { + Some(value) + } + }) + .unwrap_or_default() + } +} From a79fd76e842a97c117e6b7203489ff871d3d52f8 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 19 Apr 2024 17:36:59 -0300 Subject: [PATCH 03/21] fix: use break instead of continue --- common/src/tap/checks/value_check.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index ff9427d7f..fd2679504 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -63,7 +63,7 @@ impl MinimumValue { } } } - None => continue, + None => break, } } }); @@ -78,7 +78,7 @@ impl MinimumValue { .unwrap() .insert(query.signature.get_signature_bytes(), query); } - None => continue, + None => break, } } }); From d68d38465bda339eb184716aa38b26fd0f146d8f Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 19 Apr 2024 17:47:44 -0300 Subject: [PATCH 04/21] refactor: split in multiple functions --- common/src/tap/checks/value_check.rs | 38 +++++++++++++++++----------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index fd2679504..3f2487f20 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -68,6 +68,7 @@ impl MinimumValue { } }); + // we use two different handles because in case one channel breaks we still have the other let query_handle = tokio::spawn(async move { loop { let query = rx_query.recv().await; @@ -99,34 +100,41 @@ impl Drop for MinimumValue { } } -#[async_trait::async_trait] -impl Check for MinimumValue { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { - // get key - let key = &receipt.signed_receipt().signature.get_signature_bytes(); +impl MinimumValue { + fn get_agora_query(&self, query_id: &SignatureBytes) -> Option { + self.query_ids.lock().unwrap().remove(query_id) + } + fn get_expected_value(&self, query_id: &SignatureBytes) -> anyhow::Result { // get query from key let agora_query = self - .query_ids - .lock() - .unwrap() - .remove(key) - .ok_or(anyhow!("No query found")) - .map_err(CheckError::Failed)?; + .get_agora_query(query_id) + .ok_or(anyhow!("No query found"))?; // get agora model for the allocation_id let mut cache = self.cost_model_cache.lock().unwrap(); - // on 
average, we'll have zero or one model let models = cache.get_mut(&agora_query.deployment_id); - // get value - let value = receipt.signed_receipt().message.value; - let expected_value = models .map(|cache| cache.cost(&agora_query)) .unwrap_or_default(); + Ok(expected_value) + } +} + +#[async_trait::async_trait] +impl Check for MinimumValue { + async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + // get key + let key = &receipt.signed_receipt().signature.get_signature_bytes(); + + let expected_value = self.get_expected_value(key).map_err(CheckError::Failed)?; + + // get value + let value = receipt.signed_receipt().message.value; + let should_accept = value >= expected_value; tracing::trace!( From 1bd21b1da1ca4499c96d37f382c22d8757e077df Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 19 Apr 2024 17:50:45 -0300 Subject: [PATCH 05/21] feat: add error logs for model insert error --- .../indexer_service/http/indexer_service.rs | 6 +++ common/src/tap.rs | 8 +++ common/src/tap/checks/value_check.rs | 51 ++++++++++++++++--- service/src/routes/cost.rs | 41 +++++++++++++-- service/src/service.rs | 12 ++++- 5 files changed, 105 insertions(+), 13 deletions(-) diff --git a/common/src/indexer_service/http/indexer_service.rs b/common/src/indexer_service/http/indexer_service.rs index e3aca9640..fe884b03d 100644 --- a/common/src/indexer_service/http/indexer_service.rs +++ b/common/src/indexer_service/http/indexer_service.rs @@ -36,6 +36,7 @@ use tracing::{info, info_span}; use crate::escrow_accounts::EscrowAccounts; use crate::escrow_accounts::EscrowAccountsError; +use crate::tap::{ValueCheckReceiver, ValueCheckSender}; use crate::{ address::public_key, indexer_service::http::static_subgraph::static_subgraph_request_handler, @@ -177,6 +178,8 @@ where pub release: IndexerServiceRelease, pub url_namespace: &'static str, pub extra_routes: Router>>, + pub value_check_receiver: ValueCheckReceiver, + pub value_check_sender: ValueCheckSender, } pub struct IndexerServiceState @@ -191,6 +194,7 @@ where // tap pub escrow_accounts: Eventual, pub domain_separator: Eip712Domain, + pub value_check_sender: ValueCheckSender, } pub struct IndexerService {} @@ -320,6 +324,7 @@ impl IndexerService { domain_separator.clone(), timestamp_error_tolerance, receipt_max_value, + options.value_check_receiver, ) .await; @@ -336,6 +341,7 @@ impl IndexerService { service_impl: Arc::new(options.service_impl), escrow_accounts, domain_separator, + value_check_sender: options.value_check_sender, }); // Rate limits by allowing bursts of 10 requests and requiring 100ms of diff --git a/common/src/tap.rs b/common/src/tap.rs index 2d23ca922..c5f2d383a 100644 --- a/common/src/tap.rs +++ b/common/src/tap.rs @@ -6,6 +6,7 @@ use crate::tap::checks::deny_list_check::DenyListCheck; use crate::tap::checks::receipt_max_val_check::ReceiptMaxValueCheck; use crate::tap::checks::sender_balance_check::SenderBalanceCheck; use crate::tap::checks::timestamp_check::TimestampCheck; +use crate::tap::checks::value_check::MinimumValue; use crate::{escrow_accounts::EscrowAccounts, prelude::Allocation}; use alloy::dyn_abi::Eip712Domain; use alloy::primitives::Address; @@ -24,6 +25,11 @@ use tracing::error; mod checks; mod receipt_store; +pub use checks::value_check::{ + create_value_check, CostModelSource, ValueCheckReceiver, ValueCheckSender, +}; + +#[derive(Clone)] pub struct IndexerTapContext { domain_separator: Arc, receipt_producer: Sender, @@ -44,6 +50,7 @@ impl IndexerTapContext { domain_separator: Eip712Domain, 
timestamp_error_tolerance: Duration, receipt_max_value: u128, + value_check_receiver: ValueCheckReceiver, ) -> Vec { vec![ Arc::new(AllocationEligible::new(indexer_allocations)), @@ -54,6 +61,7 @@ impl IndexerTapContext { Arc::new(TimestampCheck::new(timestamp_error_tolerance)), Arc::new(DenyListCheck::new(pgpool, escrow_accounts, domain_separator).await), Arc::new(ReceiptMaxValueCheck::new(receipt_max_value)), + Arc::new(MinimumValue::new(value_check_receiver)), ] } diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 3f2487f20..70c18dc72 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -12,7 +12,10 @@ use std::{ time::Duration, }; use thegraph_core::DeploymentId; -use tokio::{sync::mpsc::Receiver, task::JoinHandle}; +use tokio::{ + sync::mpsc::{Receiver, Sender}, + task::JoinHandle, +}; use ttl_cache::TtlCache; use tap_core::{ @@ -31,10 +34,39 @@ pub struct MinimumValue { query_handle: JoinHandle<()>, } +#[derive(Clone)] +pub struct ValueCheckSender { + pub tx_cost_model: Sender, + pub tx_query: Sender, +} + +pub struct ValueCheckReceiver { + rx_cost_model: Receiver, + rx_query: Receiver, +} + +pub fn create_value_check(size: usize) -> (ValueCheckSender, ValueCheckReceiver) { + let (tx_cost_model, rx_cost_model) = tokio::sync::mpsc::channel(size); + let (tx_query, rx_query) = tokio::sync::mpsc::channel(size); + + ( + ValueCheckSender { + tx_query, + tx_cost_model, + }, + ValueCheckReceiver { + rx_cost_model, + rx_query, + }, + ) +} + impl MinimumValue { pub fn new( - mut rx_cost_model: Receiver, - mut rx_query: Receiver, + ValueCheckReceiver { + mut rx_query, + mut rx_cost_model, + }: ValueCheckReceiver, ) -> Self { let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); let query_ids = Arc::new(Mutex::new(HashMap::new())); @@ -48,7 +80,12 @@ impl MinimumValue { let deployment_id = value.deployment_id; if let Some(query) = cache.lock().unwrap().get_mut(&deployment_id) { - let _ = query.insert_model(value); + let _ = query.insert_model(value).inspect_err(|err| { + tracing::error!( + "Error while compiling cost model for deployment id {}. 
Error: {}", + deployment_id, err + ) + }); } else { match CostModelCache::new(value) { Ok(value) => { @@ -172,9 +209,9 @@ pub struct AgoraQuery { #[derive(Clone, Eq, Hash, PartialEq)] pub struct CostModelSource { - deployment_id: DeploymentId, - model: String, - variables: String, + pub deployment_id: DeploymentId, + pub model: String, + pub variables: String, } pub struct CostModelCache { diff --git a/service/src/routes/cost.rs b/service/src/routes/cost.rs index 08d4eb2cf..630f5987f 100644 --- a/service/src/routes/cost.rs +++ b/service/src/routes/cost.rs @@ -12,6 +12,7 @@ use prometheus::{ register_counter, register_counter_vec, register_histogram, register_histogram_vec, Counter, CounterVec, Histogram, HistogramVec, }; +use indexer_common::tap::CostModelSource; use serde::{Deserialize, Serialize}; use serde_json::Value; use thegraph_core::DeploymentId; @@ -66,6 +67,16 @@ pub struct GraphQlCostModel { pub variables: Option, } +impl From for CostModelSource { + fn from(value: CostModel) -> Self { + Self { + deployment_id: value.deployment, + model: value.model.unwrap_or_default(), + variables: value.variables.unwrap_or_default().to_string(), + } + } +} + impl From for GraphQlCostModel { fn from(model: CostModel) -> Self { Self { @@ -127,8 +138,20 @@ impl Query { ctx: &Context<'_>, deployment_ids: Vec, ) -> Result, anyhow::Error> { - let pool = &ctx.data_unchecked::>().database; + let state = &ctx.data_unchecked::>(); + + let cost_model_sender = &state.value_check_sender; + + let pool = &state.database; let cost_models = database::cost_models(pool, &deployment_ids).await?; + + for model in &cost_models { + let _ = cost_model_sender + .tx_cost_model + .send(CostModelSource::from(model.clone())) + .await; + } + Ok(cost_models.into_iter().map(|m| m.into()).collect()) } @@ -137,10 +160,20 @@ impl Query { ctx: &Context<'_>, deployment_id: DeploymentId, ) -> Result, anyhow::Error> { + + let state = &ctx.data_unchecked::>(); + let cost_model_sender = &state.value_check_sender; let pool = &ctx.data_unchecked::>().database; - database::cost_model(pool, &deployment_id) - .await - .map(|model_opt| model_opt.map(GraphQlCostModel::from)) + let model = database::cost_model(pool, &deployment_id).await?; + + if let Some(model) = &model { + let _ = cost_model_sender + .tx_cost_model + .send(CostModelSource::from(model.clone())) + .await; + } + + Ok(model.map(GraphQlCostModel::from)) } } diff --git a/service/src/service.rs b/service/src/service.rs index e583bff75..7ef9f831f 100644 --- a/service/src/service.rs +++ b/service/src/service.rs @@ -7,8 +7,9 @@ use std::time::Duration; use super::{error::SubgraphServiceError, routes}; use anyhow::anyhow; use axum::{async_trait, routing::post, Json, Router}; -use indexer_common::indexer_service::http::{ - AttestationOutput, IndexerServiceImpl, IndexerServiceResponse, +use indexer_common::{ + indexer_service::http::{AttestationOutput, IndexerServiceImpl, IndexerServiceResponse}, + tap::{create_value_check, ValueCheckSender}, }; use indexer_config::Config; use reqwest::Url; @@ -67,6 +68,7 @@ pub struct SubgraphServiceState { pub graph_node_client: reqwest::Client, pub graph_node_status_url: &'static Url, pub graph_node_query_base_url: &'static Url, + pub value_check_sender: ValueCheckSender, } struct SubgraphService { @@ -146,6 +148,9 @@ pub async fn run() -> anyhow::Result<()> { build_info::build_info!(fn build_info); let release = IndexerServiceRelease::from(build_info()); + // arbitrary value + let (value_check_sender, value_check_receiver) = create_value_check(10); 
+ // Some of the subgraph service configuration goes into the so-called // "state", which will be passed to any request handler, middleware etc. // that is involved in serving requests @@ -161,6 +166,7 @@ pub async fn run() -> anyhow::Result<()> { .expect("Failed to init HTTP client for Graph Node"), graph_node_status_url: &config.graph_node.status_url, graph_node_query_base_url: &config.graph_node.query_url, + value_check_sender: value_check_sender.clone(), }); IndexerService::run(IndexerServiceOptions { @@ -172,6 +178,8 @@ pub async fn run() -> anyhow::Result<()> { .route("/cost", post(routes::cost::cost)) .route("/status", post(routes::status)) .with_state(state), + value_check_receiver, + value_check_sender, }) .await } From cb495c7168e7736893dfd6bd0b5ee8b59d6deea0 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 19 Apr 2024 18:57:20 -0300 Subject: [PATCH 06/21] feat: send query request to value check Signed-off-by: Gustavo Inacio --- .../indexer_service/http/request_handler.rs | 29 ++++++++++++++++++- common/src/tap.rs | 2 +- common/src/tap/checks/value_check.rs | 8 ++--- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/common/src/indexer_service/http/request_handler.rs b/common/src/indexer_service/http/request_handler.rs index b7f23d681..69277b3d3 100644 --- a/common/src/indexer_service/http/request_handler.rs +++ b/common/src/indexer_service/http/request_handler.rs @@ -16,7 +16,9 @@ use reqwest::StatusCode; use thegraph_core::DeploymentId; use tracing::trace; -use crate::indexer_service::http::IndexerServiceResponse; +use serde_json::value::RawValue; + +use crate::{indexer_service::http::IndexerServiceResponse, tap::AgoraQuery}; use super::{ indexer_service::{AttestationOutput, IndexerServiceError, IndexerServiceState}, @@ -109,6 +111,31 @@ where }; let allocation_id = receipt.message.allocation_id; + let signature = receipt.signature; + + #[derive(Debug, serde::Deserialize)] + pub struct QueryBody { + pub query: String, + pub variables: Option>, + } + + let query_body: QueryBody = + serde_json::from_slice(&body).map_err(|e| IndexerServiceError::InvalidRequest(e.into()))?; + let variables = query_body + .variables + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(); + let _ = state + .value_check_sender + .tx_query + .send(AgoraQuery { + signature, + deployment_id: manifest_id, + query: query_body.query.clone(), + variables, + }) + .await; // recover the signer address // get escrow accounts from eventual diff --git a/common/src/tap.rs b/common/src/tap.rs index c5f2d383a..b3dbd3d33 100644 --- a/common/src/tap.rs +++ b/common/src/tap.rs @@ -26,7 +26,7 @@ mod checks; mod receipt_store; pub use checks::value_check::{ - create_value_check, CostModelSource, ValueCheckReceiver, ValueCheckSender, + create_value_check, AgoraQuery, CostModelSource, ValueCheckReceiver, ValueCheckSender, }; #[derive(Clone)] diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 70c18dc72..60110b945 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -201,10 +201,10 @@ fn compile_cost_model(src: CostModelSource) -> anyhow::Result { } pub struct AgoraQuery { - signature: Signature, - deployment_id: DeploymentId, - query: String, - variables: String, + pub signature: Signature, + pub deployment_id: DeploymentId, + pub query: String, + pub variables: String, } #[derive(Clone, Eq, Hash, PartialEq)] From 5fe8567f6a2ed347299ef1281b78fe91e06bf2ae Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Thu, 
19 Sep 2024 20:00:48 +0200 Subject: [PATCH 07/21] feat: use new context to pass query Signed-off-by: Gustavo Inacio --- Cargo.lock | 23 ++++-- Cargo.toml | 2 +- .../indexer_service/http/request_handler.rs | 20 +++-- common/src/tap/checks/allocation_eligible.rs | 6 +- common/src/tap/checks/deny_list_check.rs | 28 +++++-- .../src/tap/checks/receipt_max_val_check.rs | 22 ++++- common/src/tap/checks/sender_balance_check.rs | 6 +- common/src/tap/checks/timestamp_check.rs | 23 ++++-- common/src/tap/checks/value_check.rs | 80 ++++--------------- tap-agent/src/agent/sender_allocation.rs | 13 ++- .../src/tap/context/checks/allocation_id.rs | 6 +- tap-agent/src/tap/context/checks/signature.rs | 6 +- tap-agent/src/tap/context/checks/value.rs | 6 +- 13 files changed, 134 insertions(+), 107 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a4d2a737..033486e4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -701,6 +701,12 @@ version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +[[package]] +name = "anymap3" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9316150bf847270bd54e404c3fb1c5d3c6af118b3f0c3b42097da5ceafbe7c27" + [[package]] name = "ark-ff" version = "0.3.0" @@ -3149,7 +3155,7 @@ dependencies = [ "serde", "serde_json", "sqlx", - "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9)", + "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2)", "test-log", "thegraph-core", "thegraph-graphql-http", @@ -3241,7 +3247,7 @@ dependencies = [ "serde_json", "sqlx", "tap_aggregator", - "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9)", + "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2)", "tempfile", "thegraph-core", "thiserror", @@ -6168,17 +6174,14 @@ dependencies = [ [[package]] name = "tap_core" version = "1.0.0" -source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=eb8447e#eb8447ed4566ced6846c03b510b25b915f985186" +source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2#3fe6bc27c10161e5a1c011c088281d3da227df53" dependencies = [ "alloy", "anyhow", + "anymap3", "async-trait", "rand 0.8.5", - "rand_core 0.6.4", - "rstest", "serde", - "strum 0.24.1", - "strum_macros 0.24.3", "thiserror", "tokio", ] @@ -6186,13 +6189,17 @@ dependencies = [ [[package]] name = "tap_core" version = "1.0.0" -source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9#ff856d966112af4c4d554a81154797fae4b335d9" +source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=eb8447e#eb8447ed4566ced6846c03b510b25b915f985186" dependencies = [ "alloy", "anyhow", "async-trait", "rand 0.8.5", + "rand_core 0.6.4", + "rstest", "serde", + "strum 0.24.1", + "strum_macros 0.24.3", "thiserror", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 6bf9e0b65..ed66b25d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ sqlx = { version = "0.8.2", features = [ tracing = { version = "0.1.40", default-features = false } bigdecimal = "0.4.3" build-info = "0.0.39" -tap_core = { git = "https://github.com/semiotic-ai/timeline-aggregation-protocol", rev = "ff856d9", default-features = false } +tap_core = { git = "https://github.com/semiotic-ai/timeline-aggregation-protocol", rev = "3fe6bc2", default-features = false } 
tracing-subscriber = { version = "0.3", features = [ "json", "env-filter", diff --git a/common/src/indexer_service/http/request_handler.rs b/common/src/indexer_service/http/request_handler.rs index 69277b3d3..d0547143c 100644 --- a/common/src/indexer_service/http/request_handler.rs +++ b/common/src/indexer_service/http/request_handler.rs @@ -13,6 +13,7 @@ use axum_extra::TypedHeader; use lazy_static::lazy_static; use prometheus::{register_counter_vec, register_histogram_vec, CounterVec, HistogramVec}; use reqwest::StatusCode; +use tap_core::receipt::Context; use thegraph_core::DeploymentId; use tracing::trace; @@ -126,16 +127,13 @@ where .as_ref() .map(ToString::to_string) .unwrap_or_default(); - let _ = state - .value_check_sender - .tx_query - .send(AgoraQuery { - signature, - deployment_id: manifest_id, - query: query_body.query.clone(), - variables, - }) - .await; + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + signature, + deployment_id: manifest_id, + query: query_body.query.clone(), + variables, + }); // recover the signer address // get escrow accounts from eventual @@ -168,7 +166,7 @@ where // Verify the receipt and store it in the database state .tap_manager - .verify_and_store_receipt(receipt) + .verify_and_store_receipt(&ctx, receipt) .await .inspect_err(|_| { FAILED_RECEIPT diff --git a/common/src/tap/checks/allocation_eligible.rs b/common/src/tap/checks/allocation_eligible.rs index ee4e752c7..a2a106c64 100644 --- a/common/src/tap/checks/allocation_eligible.rs +++ b/common/src/tap/checks/allocation_eligible.rs @@ -27,7 +27,11 @@ impl AllocationEligible { } #[async_trait::async_trait] impl Check for AllocationEligible { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let allocation_id = receipt.signed_receipt().message.allocation_id; if !self .indexer_allocations diff --git a/common/src/tap/checks/deny_list_check.rs b/common/src/tap/checks/deny_list_check.rs index 469e6a848..cf751d29a 100644 --- a/common/src/tap/checks/deny_list_check.rs +++ b/common/src/tap/checks/deny_list_check.rs @@ -150,7 +150,11 @@ impl DenyListCheck { #[async_trait::async_trait] impl Check for DenyListCheck { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let receipt_signer = receipt .signed_receipt() .recover_signer(&self.domain_separator) @@ -195,7 +199,7 @@ mod tests { use std::str::FromStr; use alloy::hex::ToHexExt; - use tap_core::receipt::ReceiptWithState; + use tap_core::receipt::{Context, ReceiptWithState}; use crate::test_vectors::{self, create_signed_receipt, TAP_SENDER}; @@ -241,7 +245,10 @@ mod tests { let checking_receipt = ReceiptWithState::new(signed_receipt); // Check that the receipt is rejected - assert!(deny_list_check.check(&checking_receipt).await.is_err()); + assert!(deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .is_err()); } #[sqlx::test(migrations = "../migrations")] @@ -255,7 +262,10 @@ mod tests { // Check that the receipt is valid let checking_receipt = ReceiptWithState::new(signed_receipt); - deny_list_check.check(&checking_receipt).await.unwrap(); + deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .unwrap(); // Add the sender to the denylist sqlx::query!( @@ -271,7 +281,10 @@ mod tests { // Check that the receipt is rejected 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; - assert!(deny_list_check.check(&checking_receipt).await.is_err()); + assert!(deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .is_err()); // Remove the sender from the denylist sqlx::query!( @@ -287,6 +300,9 @@ mod tests { // Check that the receipt is valid again tokio::time::sleep(std::time::Duration::from_millis(100)).await; - deny_list_check.check(&checking_receipt).await.unwrap(); + deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .unwrap(); } } diff --git a/common/src/tap/checks/receipt_max_val_check.rs b/common/src/tap/checks/receipt_max_val_check.rs index 1dc7c26b6..8d480fff3 100644 --- a/common/src/tap/checks/receipt_max_val_check.rs +++ b/common/src/tap/checks/receipt_max_val_check.rs @@ -20,7 +20,11 @@ impl ReceiptMaxValueCheck { #[async_trait::async_trait] impl Check for ReceiptMaxValueCheck { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let receipt_value = receipt.signed_receipt().message.value; if receipt_value < self.receipt_max_value { @@ -42,6 +46,7 @@ mod tests { use std::str::FromStr; use std::time::Duration; use std::time::SystemTime; + use tap_core::receipt::Context; use super::*; use crate::tap::Eip712Domain; @@ -92,20 +97,29 @@ mod tests { async fn test_receipt_lower_than_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT - 1); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_ok()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_ok()); } #[tokio::test] async fn test_receipt_higher_than_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT + 1); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } #[tokio::test] async fn test_receipt_same_as_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } } diff --git a/common/src/tap/checks/sender_balance_check.rs b/common/src/tap/checks/sender_balance_check.rs index b0269e71a..08a822759 100644 --- a/common/src/tap/checks/sender_balance_check.rs +++ b/common/src/tap/checks/sender_balance_check.rs @@ -30,7 +30,11 @@ impl SenderBalanceCheck { #[async_trait::async_trait] impl Check for SenderBalanceCheck { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let escrow_accounts_snapshot = self.escrow_accounts.value_immediate().unwrap_or_default(); let receipt_signer = receipt diff --git a/common/src/tap/checks/timestamp_check.rs b/common/src/tap/checks/timestamp_check.rs index 7504433e9..219742928 100644 --- a/common/src/tap/checks/timestamp_check.rs +++ b/common/src/tap/checks/timestamp_check.rs @@ -23,7 +23,11 @@ impl TimestampCheck { #[async_trait::async_trait] impl Check for TimestampCheck { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async 
fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let timestamp_now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map_err(|e| CheckError::Failed(e.into()))?; @@ -54,7 +58,7 @@ mod tests { use super::*; use crate::tap::Eip712Domain; use tap_core::{ - receipt::{checks::Check, state::Checking, Receipt, ReceiptWithState}, + receipt::{checks::Check, state::Checking, Context, Receipt, ReceiptWithState}, signed_message::EIP712SignedMessage, tap_eip712_domain, }; @@ -98,7 +102,10 @@ mod tests { let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_ok()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_ok()); } #[tokio::test] @@ -111,7 +118,10 @@ mod tests { let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } #[tokio::test] @@ -124,6 +134,9 @@ mod tests { let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } } diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 60110b945..9689c17c9 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -18,60 +18,39 @@ use tokio::{ }; use ttl_cache::TtlCache; -use tap_core::{ - receipt::{ - checks::{Check, CheckError, CheckResult}, - state::Checking, - ReceiptWithState, - }, - signed_message::{SignatureBytes, SignatureBytesExt}, +use tap_core::receipt::{ + checks::{Check, CheckError, CheckResult}, + state::Checking, + Context, ReceiptWithState, }; pub struct MinimumValue { cost_model_cache: Arc>>, - query_ids: Arc>>, model_handle: JoinHandle<()>, - query_handle: JoinHandle<()>, } #[derive(Clone)] pub struct ValueCheckSender { pub tx_cost_model: Sender, - pub tx_query: Sender, } pub struct ValueCheckReceiver { rx_cost_model: Receiver, - rx_query: Receiver, } pub fn create_value_check(size: usize) -> (ValueCheckSender, ValueCheckReceiver) { let (tx_cost_model, rx_cost_model) = tokio::sync::mpsc::channel(size); - let (tx_query, rx_query) = tokio::sync::mpsc::channel(size); ( - ValueCheckSender { - tx_query, - tx_cost_model, - }, - ValueCheckReceiver { - rx_cost_model, - rx_query, - }, + ValueCheckSender { tx_cost_model }, + ValueCheckReceiver { rx_cost_model }, ) } impl MinimumValue { - pub fn new( - ValueCheckReceiver { - mut rx_query, - mut rx_cost_model, - }: ValueCheckReceiver, - ) -> Self { + pub fn new(ValueCheckReceiver { mut rx_cost_model }: ValueCheckReceiver) -> Self { let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); - let query_ids = Arc::new(Mutex::new(HashMap::new())); let cache = cost_model_cache.clone(); - let query_ids_clone = query_ids.clone(); let model_handle = tokio::spawn(async move { loop { let model = rx_cost_model.recv().await; @@ -105,27 +84,9 @@ impl MinimumValue { } }); - // we use 
two different handles because in case one channel breaks we still have the other - let query_handle = tokio::spawn(async move { - loop { - let query = rx_query.recv().await; - match query { - Some(query) => { - query_ids_clone - .lock() - .unwrap() - .insert(query.signature.get_signature_bytes(), query); - } - None => break, - } - } - }); - Self { cost_model_cache, model_handle, - query_ids, - query_handle, } } } @@ -133,28 +94,18 @@ impl MinimumValue { impl Drop for MinimumValue { fn drop(&mut self) { self.model_handle.abort(); - self.query_handle.abort(); } } impl MinimumValue { - fn get_agora_query(&self, query_id: &SignatureBytes) -> Option { - self.query_ids.lock().unwrap().remove(query_id) - } - - fn get_expected_value(&self, query_id: &SignatureBytes) -> anyhow::Result { - // get query from key - let agora_query = self - .get_agora_query(query_id) - .ok_or(anyhow!("No query found"))?; - + fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { // get agora model for the allocation_id let mut cache = self.cost_model_cache.lock().unwrap(); // on average, we'll have zero or one model let models = cache.get_mut(&agora_query.deployment_id); let expected_value = models - .map(|cache| cache.cost(&agora_query)) + .map(|cache| cache.cost(agora_query)) .unwrap_or_default(); Ok(expected_value) @@ -163,11 +114,14 @@ impl MinimumValue { #[async_trait::async_trait] impl Check for MinimumValue { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { - // get key - let key = &receipt.signed_receipt().signature.get_signature_bytes(); - - let expected_value = self.get_expected_value(key).map_err(CheckError::Failed)?; + async fn check(&self, ctx: &Context, receipt: &ReceiptWithState) -> CheckResult { + let agora_query = ctx + .get() + .ok_or(CheckError::Failed(anyhow!("Could not find agora query")))?; + + let expected_value = self + .get_expected_value(agora_query) + .map_err(CheckError::Failed)?; // get value let value = receipt.signed_receipt().message.value; diff --git a/tap-agent/src/agent/sender_allocation.rs b/tap-agent/src/agent/sender_allocation.rs index b44a896f3..572a25249 100644 --- a/tap-agent/src/agent/sender_allocation.rs +++ b/tap-agent/src/agent/sender_allocation.rs @@ -23,7 +23,7 @@ use tap_core::{ receipt::{ checks::{Check, CheckList}, state::Failed, - ReceiptWithState, + Context, ReceiptWithState, }, signed_message::EIP712SignedMessage, }; @@ -521,6 +521,7 @@ impl SenderAllocationState { } = self .tap_manager .create_rav_request( + &Context::new(), self.timestamp_buffer_ns, Some(self.rav_request_receipt_limit), ) @@ -884,7 +885,7 @@ pub mod tests { use tap_core::receipt::{ checks::{Check, CheckError, CheckList, CheckResult}, state::Checking, - ReceiptWithState, + Context, ReceiptWithState, }; use tokio::sync::mpsc; use wiremock::{ @@ -1495,7 +1496,11 @@ pub mod tests { #[async_trait::async_trait] impl Check for FailingCheck { - async fn check(&self, _receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + _receipt: &ReceiptWithState, + ) -> CheckResult { Err(CheckError::Failed(anyhow::anyhow!("Failing check"))) } } @@ -1517,7 +1522,7 @@ pub mod tests { .into_iter() .map(|receipt| async { receipt - .finalize_receipt_checks(&checks) + .finalize_receipt_checks(&Context::new(), &checks) .await .unwrap() .unwrap_err() diff --git a/tap-agent/src/tap/context/checks/allocation_id.rs b/tap-agent/src/tap/context/checks/allocation_id.rs index 87f05fbe2..978bd7c45 100644 --- 
a/tap-agent/src/tap/context/checks/allocation_id.rs +++ b/tap-agent/src/tap/context/checks/allocation_id.rs @@ -46,7 +46,11 @@ impl AllocationId { #[async_trait::async_trait] impl Check for AllocationId { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let allocation_id = receipt.signed_receipt().message.allocation_id; // TODO: Remove the if block below? Each TAP Monitor is specific to an allocation // ID. So the receipts that are received here should already have been filtered by diff --git a/tap-agent/src/tap/context/checks/signature.rs b/tap-agent/src/tap/context/checks/signature.rs index e4727ef6b..4b1bbe81a 100644 --- a/tap-agent/src/tap/context/checks/signature.rs +++ b/tap-agent/src/tap/context/checks/signature.rs @@ -29,7 +29,11 @@ impl Signature { #[async_trait::async_trait] impl Check for Signature { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let signer = receipt .signed_receipt() .recover_signer(&self.domain_separator) diff --git a/tap-agent/src/tap/context/checks/value.rs b/tap-agent/src/tap/context/checks/value.rs index 5265a370c..414095d1a 100644 --- a/tap-agent/src/tap/context/checks/value.rs +++ b/tap-agent/src/tap/context/checks/value.rs @@ -24,7 +24,11 @@ pub struct Value { #[async_trait::async_trait] impl Check for Value { - async fn check(&self, receipt: &ReceiptWithState) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState, + ) -> CheckResult { let value = receipt.signed_receipt().message.value; let query_id = receipt.signed_receipt().unique_hash(); From abe02e9c1ad2d0d3b683cfe789b93be886a75992 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 15:36:58 +0200 Subject: [PATCH 08/21] feat: remove signature Signed-off-by: Gustavo Inacio --- .../indexer_service/http/request_handler.rs | 2 - common/src/tap/checks/value_check.rs | 65 ++++++++++++++++++- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/common/src/indexer_service/http/request_handler.rs b/common/src/indexer_service/http/request_handler.rs index d0547143c..8cfe000c5 100644 --- a/common/src/indexer_service/http/request_handler.rs +++ b/common/src/indexer_service/http/request_handler.rs @@ -112,7 +112,6 @@ where }; let allocation_id = receipt.message.allocation_id; - let signature = receipt.signature; #[derive(Debug, serde::Deserialize)] pub struct QueryBody { @@ -129,7 +128,6 @@ where .unwrap_or_default(); let mut ctx = Context::new(); ctx.insert(AgoraQuery { - signature, deployment_id: manifest_id, query: query_body.query.clone(), variables, diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 9689c17c9..7c71739b9 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -1,10 +1,11 @@ // Copyright 2023-, GraphOps and Semiotic Labs. 
// SPDX-License-Identifier: Apache-2.0 -use alloy::signers::Signature; use anyhow::anyhow; use bigdecimal::ToPrimitive; use cost_model::CostModel; +use sqlx::{postgres::PgListener, PgPool}; +use tracing::error; use std::{ cmp::min, collections::HashMap, @@ -89,6 +90,67 @@ impl MinimumValue { model_handle, } } + + async fn cost_models_watcher( + pgpool: PgPool, + mut pglistener: PgListener, + denylist: Arc>>, + cancel_token: tokio_util::sync::CancellationToken, + ) { + #[derive(serde::Deserialize)] + struct DenylistNotification { + tg_op: String, + deployment: DeploymentId, + } + + loop { + tokio::select! { + _ = cancel_token.cancelled() => { + break; + } + + pg_notification = pglistener.recv() => { + let pg_notification = pg_notification.expect( + "should be able to receive Postgres Notify events on the channel \ + 'scalar_tap_deny_notification'", + ); + + let denylist_notification: DenylistNotification = + serde_json::from_str(pg_notification.payload()).expect( + "should be able to deserialize the Postgres Notify event payload as a \ + DenylistNotification", + ); + + match denylist_notification.tg_op.as_str() { + "INSERT" => { + denylist + .write() + .unwrap() + .insert(denylist_notification.sender_address); + } + "DELETE" => { + denylist + .write() + .unwrap() + .remove(&denylist_notification.sender_address); + } + // UPDATE and TRUNCATE are not expected to happen. Reload the entire denylist. + _ => { + error!( + "Received an unexpected denylist table notification: {}. Reloading entire \ + denylist.", + denylist_notification.tg_op + ); + + Self::sender_denylist_reload(pgpool.clone(), denylist.clone()) + .await + .expect("should be able to reload the sender denylist") + } + } + } + } + } + } } impl Drop for MinimumValue { @@ -155,7 +217,6 @@ fn compile_cost_model(src: CostModelSource) -> anyhow::Result { } pub struct AgoraQuery { - pub signature: Signature, pub deployment_id: DeploymentId, pub query: String, pub variables: String, From 38cf5619a5631373f0b81f82d658c47b03bdd706 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 16:09:53 +0200 Subject: [PATCH 09/21] refactor: use just pglistener Signed-off-by: Gustavo Inacio --- .../indexer_service/http/indexer_service.rs | 6 - common/src/tap.rs | 9 +- common/src/tap/checks/value_check.rs | 157 +++++++++--------- service/src/routes/cost.rs | 21 +-- service/src/service.rs | 12 +- 5 files changed, 81 insertions(+), 124 deletions(-) diff --git a/common/src/indexer_service/http/indexer_service.rs b/common/src/indexer_service/http/indexer_service.rs index fe884b03d..e3aca9640 100644 --- a/common/src/indexer_service/http/indexer_service.rs +++ b/common/src/indexer_service/http/indexer_service.rs @@ -36,7 +36,6 @@ use tracing::{info, info_span}; use crate::escrow_accounts::EscrowAccounts; use crate::escrow_accounts::EscrowAccountsError; -use crate::tap::{ValueCheckReceiver, ValueCheckSender}; use crate::{ address::public_key, indexer_service::http::static_subgraph::static_subgraph_request_handler, @@ -178,8 +177,6 @@ where pub release: IndexerServiceRelease, pub url_namespace: &'static str, pub extra_routes: Router>>, - pub value_check_receiver: ValueCheckReceiver, - pub value_check_sender: ValueCheckSender, } pub struct IndexerServiceState @@ -194,7 +191,6 @@ where // tap pub escrow_accounts: Eventual, pub domain_separator: Eip712Domain, - pub value_check_sender: ValueCheckSender, } pub struct IndexerService {} @@ -324,7 +320,6 @@ impl IndexerService { domain_separator.clone(), timestamp_error_tolerance, receipt_max_value, - 
options.value_check_receiver, ) .await; @@ -341,7 +336,6 @@ impl IndexerService { service_impl: Arc::new(options.service_impl), escrow_accounts, domain_separator, - value_check_sender: options.value_check_sender, }); // Rate limits by allowing bursts of 10 requests and requiring 100ms of diff --git a/common/src/tap.rs b/common/src/tap.rs index b3dbd3d33..43df83d3f 100644 --- a/common/src/tap.rs +++ b/common/src/tap.rs @@ -25,9 +25,7 @@ use tracing::error; mod checks; mod receipt_store; -pub use checks::value_check::{ - create_value_check, AgoraQuery, CostModelSource, ValueCheckReceiver, ValueCheckSender, -}; +pub use checks::value_check::{AgoraQuery, CostModelSource}; #[derive(Clone)] pub struct IndexerTapContext { @@ -50,7 +48,6 @@ impl IndexerTapContext { domain_separator: Eip712Domain, timestamp_error_tolerance: Duration, receipt_max_value: u128, - value_check_receiver: ValueCheckReceiver, ) -> Vec { vec![ Arc::new(AllocationEligible::new(indexer_allocations)), @@ -59,9 +56,9 @@ impl IndexerTapContext { domain_separator.clone(), )), Arc::new(TimestampCheck::new(timestamp_error_tolerance)), - Arc::new(DenyListCheck::new(pgpool, escrow_accounts, domain_separator).await), + Arc::new(DenyListCheck::new(pgpool.clone(), escrow_accounts, domain_separator).await), Arc::new(ReceiptMaxValueCheck::new(receipt_max_value)), - Arc::new(MinimumValue::new(value_check_receiver)), + Arc::new(MinimumValue::new(pgpool).await), ] } diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 7c71739b9..6b260515e 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -5,7 +5,6 @@ use anyhow::anyhow; use bigdecimal::ToPrimitive; use cost_model::CostModel; use sqlx::{postgres::PgListener, PgPool}; -use tracing::error; use std::{ cmp::min, collections::HashMap, @@ -13,10 +12,8 @@ use std::{ time::Duration, }; use thegraph_core::DeploymentId; -use tokio::{ - sync::mpsc::{Receiver, Sender}, - task::JoinHandle, -}; +use tokio::task::JoinHandle; +use tracing::error; use ttl_cache::TtlCache; use tap_core::receipt::{ @@ -30,60 +27,25 @@ pub struct MinimumValue { model_handle: JoinHandle<()>, } -#[derive(Clone)] -pub struct ValueCheckSender { - pub tx_cost_model: Sender, -} - -pub struct ValueCheckReceiver { - rx_cost_model: Receiver, -} +impl MinimumValue { + pub async fn new(pgpool: PgPool) -> Self { + let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); -pub fn create_value_check(size: usize) -> (ValueCheckSender, ValueCheckReceiver) { - let (tx_cost_model, rx_cost_model) = tokio::sync::mpsc::channel(size); + let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); + pglistener.listen("cost_models_update_notify").await.expect( + "should be able to subscribe to Postgres Notify events on the channel \ + 'cost_models_update_notify'", + ); - ( - ValueCheckSender { tx_cost_model }, - ValueCheckReceiver { rx_cost_model }, - ) -} + // TODO start watcher + let cancel_token = tokio_util::sync::CancellationToken::new(); -impl MinimumValue { - pub fn new(ValueCheckReceiver { mut rx_cost_model }: ValueCheckReceiver) -> Self { - let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); - let cache = cost_model_cache.clone(); - let model_handle = tokio::spawn(async move { - loop { - let model = rx_cost_model.recv().await; - match model { - Some(value) => { - let deployment_id = value.deployment_id; - - if let Some(query) = cache.lock().unwrap().get_mut(&deployment_id) { - let _ = 
query.insert_model(value).inspect_err(|err| { - tracing::error!( - "Error while compiling cost model for deployment id {}. Error: {}", - deployment_id, err - ) - }); - } else { - match CostModelCache::new(value) { - Ok(value) => { - cache.lock().unwrap().insert(deployment_id, value); - } - Err(err) => { - tracing::error!( - "Error while compiling cost model for deployment id {}. Error: {}", - deployment_id, err - ) - } - } - } - } - None => break, - } - } - }); + let model_handle = tokio::spawn(Self::cost_models_watcher( + pgpool.clone(), + pglistener, + cost_model_cache.clone(), + cancel_token.clone(), + )); Self { cost_model_cache, @@ -92,17 +54,11 @@ impl MinimumValue { } async fn cost_models_watcher( - pgpool: PgPool, + _pgpool: PgPool, mut pglistener: PgListener, - denylist: Arc>>, + cost_model_cache: Arc>>, cancel_token: tokio_util::sync::CancellationToken, ) { - #[derive(serde::Deserialize)] - struct DenylistNotification { - tg_op: String, - deployment: DeploymentId, - } - loop { tokio::select! { _ = cancel_token.cancelled() => { @@ -112,39 +68,58 @@ impl MinimumValue { pg_notification = pglistener.recv() => { let pg_notification = pg_notification.expect( "should be able to receive Postgres Notify events on the channel \ - 'scalar_tap_deny_notification'", + 'cost_models_update_notify'", ); - let denylist_notification: DenylistNotification = + let cost_model_notification: CostModelNotification = serde_json::from_str(pg_notification.payload()).expect( "should be able to deserialize the Postgres Notify event payload as a \ - DenylistNotification", + CostModelNotification", ); - match denylist_notification.tg_op.as_str() { + let deployment_id = cost_model_notification.deployment; + + match cost_model_notification.tg_op.as_str() { "INSERT" => { - denylist - .write() - .unwrap() - .insert(denylist_notification.sender_address); + let cost_model_source: CostModelSource = cost_model_notification.into(); + let mut cost_model_cache = cost_model_cache + .lock() + .unwrap(); + + match cost_model_cache.get_mut(&deployment_id) { + Some(cache) => { + let _ = cache.insert_model(cost_model_source); + }, + None => { + if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| { + tracing::error!( + "Error while compiling cost model for deployment id {}. Error: {}", + deployment_id, err + ) + }) { + cost_model_cache.insert(deployment_id, cache); + } + }, + } } "DELETE" => { - denylist - .write() + cost_model_cache + .lock() .unwrap() - .remove(&denylist_notification.sender_address); + .remove(&cost_model_notification.deployment); } - // UPDATE and TRUNCATE are not expected to happen. Reload the entire denylist. + // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost + // model cache. _ => { error!( - "Received an unexpected denylist table notification: {}. Reloading entire \ - denylist.", - denylist_notification.tg_op + "Received an unexpected cost model table notification: {}. 
Reloading entire \ + cost model.", + cost_model_notification.tg_op ); - Self::sender_denylist_reload(pgpool.clone(), denylist.clone()) - .await - .expect("should be able to reload the sender denylist") + // Self::sender_denylist_reload(pgpool.clone(), denylist.clone()) + // .await + // .expect("should be able to reload cost models") } } } @@ -229,6 +204,24 @@ pub struct CostModelSource { pub variables: String, } +#[derive(serde::Deserialize)] +struct CostModelNotification { + tg_op: String, + deployment: DeploymentId, + model: String, + variables: String, +} + +impl From for CostModelSource { + fn from(value: CostModelNotification) -> Self { + CostModelSource { + deployment_id: value.deployment, + model: value.model, + variables: value.variables, + } + } +} + pub struct CostModelCache { models: TtlCache, latest_model: CostModel, diff --git a/service/src/routes/cost.rs b/service/src/routes/cost.rs index 630f5987f..fbdbd9df8 100644 --- a/service/src/routes/cost.rs +++ b/service/src/routes/cost.rs @@ -7,12 +7,12 @@ use std::sync::Arc; use async_graphql::{Context, EmptyMutation, EmptySubscription, Object, Schema, SimpleObject}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use axum::extract::State; +use indexer_common::tap::CostModelSource; use lazy_static::lazy_static; use prometheus::{ register_counter, register_counter_vec, register_histogram, register_histogram_vec, Counter, CounterVec, Histogram, HistogramVec, }; -use indexer_common::tap::CostModelSource; use serde::{Deserialize, Serialize}; use serde_json::Value; use thegraph_core::DeploymentId; @@ -140,18 +140,9 @@ impl Query { ) -> Result, anyhow::Error> { let state = &ctx.data_unchecked::>(); - let cost_model_sender = &state.value_check_sender; - let pool = &state.database; let cost_models = database::cost_models(pool, &deployment_ids).await?; - for model in &cost_models { - let _ = cost_model_sender - .tx_cost_model - .send(CostModelSource::from(model.clone())) - .await; - } - Ok(cost_models.into_iter().map(|m| m.into()).collect()) } @@ -160,19 +151,9 @@ impl Query { ctx: &Context<'_>, deployment_id: DeploymentId, ) -> Result, anyhow::Error> { - - let state = &ctx.data_unchecked::>(); - let cost_model_sender = &state.value_check_sender; let pool = &ctx.data_unchecked::>().database; let model = database::cost_model(pool, &deployment_id).await?; - if let Some(model) = &model { - let _ = cost_model_sender - .tx_cost_model - .send(CostModelSource::from(model.clone())) - .await; - } - Ok(model.map(GraphQlCostModel::from)) } } diff --git a/service/src/service.rs b/service/src/service.rs index 7ef9f831f..e583bff75 100644 --- a/service/src/service.rs +++ b/service/src/service.rs @@ -7,9 +7,8 @@ use std::time::Duration; use super::{error::SubgraphServiceError, routes}; use anyhow::anyhow; use axum::{async_trait, routing::post, Json, Router}; -use indexer_common::{ - indexer_service::http::{AttestationOutput, IndexerServiceImpl, IndexerServiceResponse}, - tap::{create_value_check, ValueCheckSender}, +use indexer_common::indexer_service::http::{ + AttestationOutput, IndexerServiceImpl, IndexerServiceResponse, }; use indexer_config::Config; use reqwest::Url; @@ -68,7 +67,6 @@ pub struct SubgraphServiceState { pub graph_node_client: reqwest::Client, pub graph_node_status_url: &'static Url, pub graph_node_query_base_url: &'static Url, - pub value_check_sender: ValueCheckSender, } struct SubgraphService { @@ -148,9 +146,6 @@ pub async fn run() -> anyhow::Result<()> { build_info::build_info!(fn build_info); let release = 
IndexerServiceRelease::from(build_info()); - // arbitrary value - let (value_check_sender, value_check_receiver) = create_value_check(10); - // Some of the subgraph service configuration goes into the so-called // "state", which will be passed to any request handler, middleware etc. // that is involved in serving requests @@ -166,7 +161,6 @@ pub async fn run() -> anyhow::Result<()> { .expect("Failed to init HTTP client for Graph Node"), graph_node_status_url: &config.graph_node.status_url, graph_node_query_base_url: &config.graph_node.query_url, - value_check_sender: value_check_sender.clone(), }); IndexerService::run(IndexerServiceOptions { @@ -178,8 +172,6 @@ pub async fn run() -> anyhow::Result<()> { .route("/cost", post(routes::cost::cost)) .route("/status", post(routes::status)) .with_state(state), - value_check_receiver, - value_check_sender, }) .await } From f67e6143c5e0fd56a7b86830c663e662565e407c Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 16:13:53 +0200 Subject: [PATCH 10/21] refactor: update cost.rs Signed-off-by: Gustavo Inacio --- service/src/routes/cost.rs | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/service/src/routes/cost.rs b/service/src/routes/cost.rs index fbdbd9df8..08d4eb2cf 100644 --- a/service/src/routes/cost.rs +++ b/service/src/routes/cost.rs @@ -7,7 +7,6 @@ use std::sync::Arc; use async_graphql::{Context, EmptyMutation, EmptySubscription, Object, Schema, SimpleObject}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use axum::extract::State; -use indexer_common::tap::CostModelSource; use lazy_static::lazy_static; use prometheus::{ register_counter, register_counter_vec, register_histogram, register_histogram_vec, Counter, @@ -67,16 +66,6 @@ pub struct GraphQlCostModel { pub variables: Option, } -impl From for CostModelSource { - fn from(value: CostModel) -> Self { - Self { - deployment_id: value.deployment, - model: value.model.unwrap_or_default(), - variables: value.variables.unwrap_or_default().to_string(), - } - } -} - impl From for GraphQlCostModel { fn from(model: CostModel) -> Self { Self { @@ -138,11 +127,8 @@ impl Query { ctx: &Context<'_>, deployment_ids: Vec, ) -> Result, anyhow::Error> { - let state = &ctx.data_unchecked::>(); - - let pool = &state.database; + let pool = &ctx.data_unchecked::>().database; let cost_models = database::cost_models(pool, &deployment_ids).await?; - Ok(cost_models.into_iter().map(|m| m.into()).collect()) } @@ -152,9 +138,9 @@ impl Query { deployment_id: DeploymentId, ) -> Result, anyhow::Error> { let pool = &ctx.data_unchecked::>().database; - let model = database::cost_model(pool, &deployment_id).await?; - - Ok(model.map(GraphQlCostModel::from)) + database::cost_model(pool, &deployment_id) + .await + .map(|model_opt| model_opt.map(GraphQlCostModel::from)) } } From 26cf47a703f870c43d03bbdbac29974921cae2d9 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 16:18:25 +0200 Subject: [PATCH 11/21] refactor: deserialize only once Signed-off-by: Gustavo Inacio --- .../indexer_service/http/indexer_service.rs | 7 +++---- .../indexer_service/http/request_handler.rs | 20 +++++++++---------- service/src/service.rs | 8 ++++---- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/common/src/indexer_service/http/indexer_service.rs b/common/src/indexer_service/http/indexer_service.rs index e3aca9640..887db7f82 100644 --- a/common/src/indexer_service/http/indexer_service.rs +++ b/common/src/indexer_service/http/indexer_service.rs @@ 
-66,15 +66,14 @@ pub enum AttestationOutput { #[async_trait] pub trait IndexerServiceImpl { type Error: std::error::Error; - type Request: DeserializeOwned + Send + Debug + Serialize; type Response: IndexerServiceResponse + Sized; type State: Send + Sync; - async fn process_request( + async fn process_request( &self, manifest_id: DeploymentId, - request: Self::Request, - ) -> Result<(Self::Request, Self::Response), Self::Error>; + request: Request, + ) -> Result<(Request, Self::Response), Self::Error>; } #[derive(Debug, Error)] diff --git a/common/src/indexer_service/http/request_handler.rs b/common/src/indexer_service/http/request_handler.rs index 8cfe000c5..c3a7a6486 100644 --- a/common/src/indexer_service/http/request_handler.rs +++ b/common/src/indexer_service/http/request_handler.rs @@ -80,7 +80,13 @@ where { trace!("Handling request for deployment `{manifest_id}`"); - let request = + #[derive(Debug, serde::Deserialize, serde::Serialize)] + pub struct QueryBody { + pub query: String, + pub variables: Option>, + } + + let request: QueryBody = serde_json::from_slice(&body).map_err(|e| IndexerServiceError::InvalidRequest(e.into()))?; let Some(receipt) = receipt.into_signed_receipt() else { @@ -113,15 +119,7 @@ where let allocation_id = receipt.message.allocation_id; - #[derive(Debug, serde::Deserialize)] - pub struct QueryBody { - pub query: String, - pub variables: Option>, - } - - let query_body: QueryBody = - serde_json::from_slice(&body).map_err(|e| IndexerServiceError::InvalidRequest(e.into()))?; - let variables = query_body + let variables = request .variables .as_ref() .map(ToString::to_string) @@ -129,7 +127,7 @@ where let mut ctx = Context::new(); ctx.insert(AgoraQuery { deployment_id: manifest_id, - query: query_body.query.clone(), + query: request.query.clone(), variables, }); diff --git a/service/src/service.rs b/service/src/service.rs index e583bff75..12504677b 100644 --- a/service/src/service.rs +++ b/service/src/service.rs @@ -12,6 +12,7 @@ use indexer_common::indexer_service::http::{ }; use indexer_config::Config; use reqwest::Url; +use serde::{de::DeserializeOwned, Serialize}; use serde_json::{json, Value}; use sqlx::PgPool; use thegraph_core::DeploymentId; @@ -82,15 +83,14 @@ impl SubgraphService { #[async_trait] impl IndexerServiceImpl for SubgraphService { type Error = SubgraphServiceError; - type Request = serde_json::Value; type Response = SubgraphServiceResponse; type State = SubgraphServiceState; - async fn process_request( + async fn process_request( &self, deployment: DeploymentId, - request: Self::Request, - ) -> Result<(Self::Request, Self::Response), Self::Error> { + request: Request, + ) -> Result<(Request, Self::Response), Self::Error> { let deployment_url = self .state .graph_node_query_base_url From e8eaae716cbf043876769b9b0923e563938f59cd Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 16:27:29 +0200 Subject: [PATCH 12/21] chore: update message Signed-off-by: Gustavo Inacio --- common/src/tap/checks/value_check.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 6b260515e..3555d0899 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -176,7 +176,7 @@ impl Check for MinimumValue { Ok(()) } else { return Err(CheckError::Failed(anyhow!( - "Query receipt does not have the minimum value. Expected value: {}. Minimum value: {}.", + "Query receipt does not have the minimum value. Expected value: {}. 
Received value: {}.", expected_value, value, ))); } From a609a8214ef6f7015943951addebac6585bbf296 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Mon, 7 Oct 2024 17:18:12 +0200 Subject: [PATCH 13/21] refactor: update cost model to use history Signed-off-by: Gustavo Inacio --- common/src/tap/checks/value_check.rs | 27 ++++++------ .../20230901142040_cost_models.down.sql | 8 +++- migrations/20230901142040_cost_models.up.sql | 43 ++++++++++++++++-- service/src/database.rs | 44 ++++++------------- 4 files changed, 74 insertions(+), 48 deletions(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 3555d0899..396e35c0e 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -12,7 +12,6 @@ use std::{ time::Duration, }; use thegraph_core::DeploymentId; -use tokio::task::JoinHandle; use tracing::error; use ttl_cache::TtlCache; @@ -24,7 +23,15 @@ use tap_core::receipt::{ pub struct MinimumValue { cost_model_cache: Arc>>, - model_handle: JoinHandle<()>, + watcher_cancel_token: tokio_util::sync::CancellationToken, +} + +impl Drop for MinimumValue { + fn drop(&mut self) { + // Clean shutdown for the sender_denylist_watcher + // Though since it's not a critical task, we don't wait for it to finish (join). + self.watcher_cancel_token.cancel(); + } } impl MinimumValue { @@ -37,19 +44,17 @@ impl MinimumValue { 'cost_models_update_notify'", ); - // TODO start watcher - let cancel_token = tokio_util::sync::CancellationToken::new(); - - let model_handle = tokio::spawn(Self::cost_models_watcher( + let watcher_cancel_token = tokio_util::sync::CancellationToken::new(); + tokio::spawn(Self::cost_models_watcher( pgpool.clone(), pglistener, cost_model_cache.clone(), - cancel_token.clone(), + watcher_cancel_token.clone(), )); Self { cost_model_cache, - model_handle, + watcher_cancel_token, } } @@ -128,12 +133,6 @@ impl MinimumValue { } } -impl Drop for MinimumValue { - fn drop(&mut self) { - self.model_handle.abort(); - } -} - impl MinimumValue { fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { // get agora model for the allocation_id diff --git a/migrations/20230901142040_cost_models.down.sql b/migrations/20230901142040_cost_models.down.sql index 885eeefc6..744b443e9 100644 --- a/migrations/20230901142040_cost_models.down.sql +++ b/migrations/20230901142040_cost_models.down.sql @@ -1,2 +1,8 @@ -- Add down migration script here -DROP TABLE "CostModels"; +DROP TRIGGER IF EXISTS cost_models_update ON "CostModelsHistory" CASCADE; + +DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE; + +DROP VIEW "CostModels"; + +DROP TABLE "CostModelsHistory"; diff --git a/migrations/20230901142040_cost_models.up.sql b/migrations/20230901142040_cost_models.up.sql index 96b10f0bd..a5e1862fc 100644 --- a/migrations/20230901142040_cost_models.up.sql +++ b/migrations/20230901142040_cost_models.up.sql @@ -1,8 +1,45 @@ -CREATE TABLE IF NOT EXISTS "CostModels" +CREATE TABLE IF NOT EXISTS "CostModelsHistory" ( - id INT, + id SERIAL PRIMARY KEY, deployment VARCHAR NOT NULL, model TEXT, variables JSONB, - PRIMARY KEY( deployment ) + "createdAt" TIMESTAMP WITH TIME ZONE, + "updatedAt" TIMESTAMP WITH TIME ZONE ); + +CREATE VIEW "CostModels" AS SELECT id, + deployment, + model, + variables, + "createdAt", + "updatedAt" + FROM "CostModelsHistory" t1 + JOIN + ( + SELECT MAX(id) + FROM "CostModelsHistory" + GROUP BY deployment + ) t2 + ON t1.id = t2.MAX; + +CREATE FUNCTION cost_models_update_notify() +RETURNS trigger AS +$$ 
+BEGIN
+    IF TG_OP = 'DELETE' THEN
+        PERFORM pg_notify('cost_models_update_notify', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment));
+        RETURN OLD;
+    ELSIF TG_OP = 'INSERT' THEN
+        PERFORM pg_notify('cost_models_update_notify', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
+        RETURN NEW;
+    ELSE
+        PERFORM pg_notify('cost_models_update_notify', format('{"tg_op": "%s", "deployment": "%s", "model": "%s"}', TG_OP, NEW.deployment, NEW.model));
+        RETURN NEW;
+    END IF;
+END;
+$$ LANGUAGE 'plpgsql';
+
+CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE
+    ON "CostModelsHistory"
+    FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();
diff --git a/service/src/database.rs b/service/src/database.rs
index ab0cd80e0..31708b3c3 100644
--- a/service/src/database.rs
+++ b/service/src/database.rs
@@ -26,7 +26,7 @@ pub async fn connect(url: &str) -> PgPool {
 /// These can have "global" as the deployment ID.
 #[derive(Debug, Clone)]
 struct DbCostModel {
-    pub deployment: String,
+    pub deployment: Option<String>,
     pub model: Option<String>,
     pub variables: Option<Value>,
 }
@@ -46,7 +46,12 @@ impl TryFrom<DbCostModel> for CostModel {
     fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> {
         Ok(Self {
-            deployment: DeploymentId::from_str(&db_model.deployment)?,
+            deployment: DeploymentId::from_str(&db_model.deployment.ok_or(
+                ParseDeploymentIdError::InvalidIpfsHashLength {
+                    value: String::new(),
+                    length: 0,
+                },
+            )?)?,
             model: db_model.model,
             variables: db_model.variables,
         })
@@ -57,7 +62,7 @@ impl From<CostModel> for DbCostModel {
     fn from(model: CostModel) -> Self {
         let deployment = model.deployment;
         DbCostModel {
-            deployment: format!("{deployment:#x}"),
+            deployment: Some(format!("{deployment:#x}")),
             model: model.model,
             variables: model.variables,
         }
@@ -210,28 +215,11 @@ mod test {
 
     use super::*;
 
-    async fn setup_cost_models_table(pool: &PgPool) {
-        sqlx::query!(
-            r#"
-            CREATE TABLE "CostModels"(
-                id INT,
-                deployment VARCHAR NOT NULL,
-                model TEXT,
-                variables JSONB,
-                PRIMARY KEY( deployment )
-            );
-            "#,
-        )
-        .execute(pool)
-        .await
-        .expect("Create test instance in db");
-    }
-
     async fn add_cost_models(pool: &PgPool, models: Vec<DbCostModel>) {
         for model in models {
             sqlx::query!(
                 r#"
-                INSERT INTO "CostModels" (deployment, model)
+                INSERT INTO "CostModelsHistory" (deployment, model)
                 VALUES ($1, $2);
                 "#,
                 model.deployment,
@@ -249,7 +237,7 @@ mod test {
 
     fn global_cost_model() -> DbCostModel {
         DbCostModel {
-            deployment: "global".to_string(),
+            deployment: Some("global".to_string()),
             model: Some("default => 0.00001;".to_string()),
             variables: None,
         }
     }
@@ -281,7 +269,7 @@ mod test {
         ]
     }
 
-    #[sqlx::test]
+    #[sqlx::test(migrations = "../migrations")]
     async fn success_cost_models(pool: PgPool) {
         let test_models = test_data();
         let test_deployments = test_models
             .iter()
             .map(|model| model.deployment)
             .collect::<Vec<_>>();
 
-        setup_cost_models_table(&pool).await;
         add_cost_models(&pool, to_db_models(test_models.clone())).await;
 
         // First test: query without deployment filter
@@ -344,7 +331,7 @@ mod test {
         }
     }
 
-    #[sqlx::test]
+    #[sqlx::test(migrations = "../migrations")]
     async fn global_fallback_cost_models(pool: PgPool) {
         let test_models = test_data();
         let test_deployments = test_models
             .iter()
             .map(|model| model.deployment)
             .collect::<Vec<_>>();
         let global_model = global_cost_model();
 
-        setup_cost_models_table(&pool).await;
         add_cost_models(&pool, to_db_models(test_models.clone())).await;
         add_cost_models(&pool, vec![global_model.clone()]).await;
 
@@ -436,9 +422,8 @@
assert_eq!(missing_model.model, global_model.model); } - #[sqlx::test] + #[sqlx::test(migrations = "../migrations")] async fn success_cost_model(pool: PgPool) { - setup_cost_models_table(&pool).await; add_cost_models(&pool, to_db_models(test_data())).await; let deployment_id_from_bytes = DeploymentId::from_str( @@ -459,12 +444,11 @@ mod test { assert_eq!(model.model, Some("default => 0.00025;".to_string())); } - #[sqlx::test] + #[sqlx::test(migrations = "../migrations")] async fn global_fallback_cost_model(pool: PgPool) { let test_models = test_data(); let global_model = global_cost_model(); - setup_cost_models_table(&pool).await; add_cost_models(&pool, to_db_models(test_models.clone())).await; add_cost_models(&pool, vec![global_model.clone()]).await; From 77911cc40fb86d8c997ce86ff242897a46305f7f Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Tue, 8 Oct 2024 15:29:16 +0200 Subject: [PATCH 14/21] refactor: add value reload Signed-off-by: Gustavo Inacio --- common/src/tap/checks/value_check.rs | 100 +++++++++++++++++++++------ 1 file changed, 77 insertions(+), 23 deletions(-) diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 396e35c0e..d28c5c314 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -8,7 +8,8 @@ use sqlx::{postgres::PgListener, PgPool}; use std::{ cmp::min, collections::HashMap, - sync::{Arc, Mutex}, + str::FromStr, + sync::{Arc, Mutex, RwLock}, time::Duration, }; use thegraph_core::DeploymentId; @@ -22,7 +23,7 @@ use tap_core::receipt::{ }; pub struct MinimumValue { - cost_model_cache: Arc>>, + cost_model_cache: Arc>>>, watcher_cancel_token: tokio_util::sync::CancellationToken, } @@ -36,7 +37,9 @@ impl Drop for MinimumValue { impl MinimumValue { pub async fn new(pgpool: PgPool) -> Self { - let cost_model_cache = Arc::new(Mutex::new(HashMap::::new())); + let cost_model_cache = Arc::new(RwLock::new( + HashMap::>::new(), + )); let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); pglistener.listen("cost_models_update_notify").await.expect( @@ -58,10 +61,23 @@ impl MinimumValue { } } + fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { + // get agora model for the allocation_id + let cache = self.cost_model_cache.read().unwrap(); + // on average, we'll have zero or one model + let models = cache.get(&agora_query.deployment_id); + + let expected_value = models + .map(|cache| cache.lock().unwrap().cost(agora_query)) + .unwrap_or_default(); + + Ok(expected_value) + } + async fn cost_models_watcher( - _pgpool: PgPool, + pgpool: PgPool, mut pglistener: PgListener, - cost_model_cache: Arc>>, + cost_model_cache: Arc>>>, cancel_token: tokio_util::sync::CancellationToken, ) { loop { @@ -88,12 +104,12 @@ impl MinimumValue { "INSERT" => { let cost_model_source: CostModelSource = cost_model_notification.into(); let mut cost_model_cache = cost_model_cache - .lock() + .write() .unwrap(); match cost_model_cache.get_mut(&deployment_id) { Some(cache) => { - let _ = cache.insert_model(cost_model_source); + let _ = cache.lock().unwrap().insert_model(cost_model_source); }, None => { if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| { @@ -102,14 +118,14 @@ impl MinimumValue { deployment_id, err ) }) { - cost_model_cache.insert(deployment_id, cache); + cost_model_cache.insert(deployment_id, Mutex::new(cache)); } }, } } "DELETE" => { cost_model_cache - .lock() + .write() .unwrap() .remove(&cost_model_notification.deployment); } @@ 
-122,29 +138,47 @@ impl MinimumValue { cost_model_notification.tg_op ); - // Self::sender_denylist_reload(pgpool.clone(), denylist.clone()) - // .await - // .expect("should be able to reload cost models") + Self::value_check_reload(&pgpool, cost_model_cache.clone()) + .await + .expect("should be able to reload cost models") } } } } } } -} -impl MinimumValue { - fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { - // get agora model for the allocation_id - let mut cache = self.cost_model_cache.lock().unwrap(); - // on average, we'll have zero or one model - let models = cache.get_mut(&agora_query.deployment_id); + async fn value_check_reload( + pgpool: &PgPool, + cost_model_cache: Arc>>>, + ) -> anyhow::Result<()> { + let models = sqlx::query!( + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment != 'global' + ORDER BY deployment ASC + "# + ) + .fetch_all(pgpool) + .await?; + let models = models + .into_iter() + .map(|record| { + let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?; + let model = CostModelCache::new(CostModelSource { + deployment_id, + model: record.model.unwrap(), + variables: record.variables.unwrap().to_string(), + })?; + + Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model))) + }) + .collect::, _>>()?; - let expected_value = models - .map(|cache| cache.cost(agora_query)) - .unwrap_or_default(); + *(cost_model_cache.write().unwrap()) = models; - Ok(expected_value) + Ok(()) } } @@ -279,3 +313,23 @@ impl CostModelCache { .unwrap_or_default() } } + +#[cfg(test)] +mod tests { + use sqlx::PgPool; + + #[sqlx::test(migrations = "../migrations")] + async fn initialize_check(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_initialize_check_with_caches(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_expire_old_model(pg_pool: PgPool) {} + + #[sqlx::test(migrations = "../migrations")] + async fn should_verify_global_model(pg_pool: PgPool) {} +} From 238f00ececa33b61f4ee2ea13e85894f7df689f3 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Thu, 10 Oct 2024 23:05:47 +0200 Subject: [PATCH 15/21] chore: update sqlx Signed-off-by: Gustavo Inacio --- ...2366d18653f57a9d4b6f231c00b9eea60392a0a6f.json | 15 +++++++++++++++ ...f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json | 12 ------------ ...c568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json | 2 +- ...9c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json | 2 +- ...9a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json | 2 +- ...0c204651cfd94b4b82f26baf0755efa80e6045c0a.json | 2 +- ...16d7893b4f7209c985c367215b2eed25adc78c462.json | 15 --------------- 7 files changed, 19 insertions(+), 31 deletions(-) create mode 100644 .sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json delete mode 100644 .sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json delete mode 100644 .sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json diff --git a/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json b/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json new file mode 100644 index 000000000..b617582e0 --- /dev/null +++ b/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO 
\"CostModelsHistory\" (deployment, model)\n VALUES ($1, $2);\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Varchar", + "Text" + ] + }, + "nullable": [] + }, + "hash": "4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f" +} diff --git a/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json b/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json deleted file mode 100644 index 033c25bb3..000000000 --- a/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n CREATE TABLE \"CostModels\"(\n id INT,\n deployment VARCHAR NOT NULL,\n model TEXT,\n variables JSONB,\n PRIMARY KEY( deployment )\n );\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [] - }, - "nullable": [] - }, - "hash": "58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6" -} diff --git a/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json b/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json index 2b0cde21e..91e33b3ed 100644 --- a/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json +++ b/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json @@ -23,7 +23,7 @@ "Left": [] }, "nullable": [ - false, + true, true, true ] diff --git a/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json b/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json index 7dbd08638..d25a9de5d 100644 --- a/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json +++ b/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - false, + true, true, true ] diff --git a/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json b/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json index 6b2da69c8..c9e8b2f4e 100644 --- a/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json +++ b/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - false, + true, true, true ] diff --git a/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json b/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json index 619967273..08c287f3f 100644 --- a/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json +++ b/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - false, + true, true, true ] diff --git a/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json b/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json deleted file mode 100644 index 0c8cfd917..000000000 --- a/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO \"CostModels\" (deployment, model)\n VALUES ($1, $2);\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Varchar", - "Text" - ] - }, - "nullable": [] - }, - "hash": "ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462" -} From 8fb750402da648fecd371136325bcc9a9fde0fc4 Mon Sep 17 00:00:00 2001 From: Gustavo 
Inacio
Date: Fri, 11 Oct 2024 22:30:11 +0200
Subject: [PATCH 16/21] refactor: remove ttl, use tasks for expire

Signed-off-by: Gustavo Inacio
---
 Cargo.lock                           |  16 ---
 common/Cargo.toml                    |   1 -
 common/src/tap.rs                    |   2 +-
 common/src/tap/checks/value_check.rs | 183 +++++++++++++++------------
 4 files changed, 104 insertions(+), 98 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 033486e4b..9472072cc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3165,7 +3165,6 @@ dependencies = [
 "tower-http",
 "tower_governor",
 "tracing",
- "ttl_cache",
 "wiremock 0.5.22",
 ]
 
@@ -3650,12 +3649,6 @@ dependencies = [
 "vcpkg",
 ]
 
-[[package]]
-name = "linked-hash-map"
-version = "0.5.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
-
 [[package]]
 name = "linkme"
 version = "0.3.28"
@@ -6702,15 +6695,6 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
-[[package]]
-name = "ttl_cache"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a"
-dependencies = [
- "linked-hash-map",
-]
-
 [[package]]
 name = "tungstenite"
 version = "0.23.0"
diff --git a/common/Cargo.toml b/common/Cargo.toml
index 4cf3515c1..911fcc048 100644
--- a/common/Cargo.toml
+++ b/common/Cargo.toml
@@ -31,7 +31,6 @@ regex = "1.7.1"
 axum-extra = { version = "0.9.3", features = [
     "typed-header",
 ], default-features = false }
-ttl_cache = "0.5.1"
 autometrics = { version = "1.0.1", features = ["prometheus-exporter"] }
 tower_governor = "0.3.2"
 tower-http = { version = "0.5.2", features = [
diff --git a/common/src/tap.rs b/common/src/tap.rs
index 43df83d3f..151281b43 100644
--- a/common/src/tap.rs
+++ b/common/src/tap.rs
@@ -25,7 +25,7 @@ use tracing::error;
 mod checks;
 mod receipt_store;
 
-pub use checks::value_check::{AgoraQuery, CostModelSource};
+pub use checks::value_check::AgoraQuery;
 
 #[derive(Clone)]
 pub struct IndexerTapContext {
diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index d28c5c314..65c3e2e43 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -7,14 +7,14 @@ use cost_model::CostModel;
 use sqlx::{postgres::PgListener, PgPool};
 use std::{
     cmp::min,
-    collections::HashMap,
+    collections::{hash_map::Entry, HashMap, VecDeque},
     str::FromStr,
     sync::{Arc, Mutex, RwLock},
     time::Duration,
 };
-use thegraph_core::DeploymentId;
+use thegraph_core::{DeploymentId, ParseDeploymentIdError};
+use tokio::{sync::mpsc::channel, task::JoinHandle, time::sleep};
 use tracing::error;
-use ttl_cache::TtlCache;
 
 use tap_core::receipt::{
@@ -22,8 +22,13 @@ use tap_core::receipt::{
     Context, ReceiptWithState,
 };
 
+// we only accept receipts with a minimum value of 1 wei GRT
+const MINIMAL_VALUE: u128 = 1;
+
+type CostModelMap = Arc<RwLock<HashMap<DeploymentId, RwLock<CostModelCache>>>>;
+
 pub struct MinimumValue {
-    cost_model_cache: Arc<RwLock<HashMap<DeploymentId, Mutex<CostModelCache>>>>,
+    cost_model_cache: CostModelMap,
     watcher_cancel_token: tokio_util::sync::CancellationToken,
 }
 
@@ -37,9 +42,7 @@ impl Drop for MinimumValue {
 
 impl MinimumValue {
     pub async fn new(pgpool: PgPool) -> Self {
-        let cost_model_cache = Arc::new(RwLock::new(
-            HashMap::<DeploymentId, Mutex<CostModelCache>>::new(),
-        ));
+        let cost_model_cache: CostModelMap = Default::default();
 
         let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap();
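        // PgListener takes a dedicated connection derived from the pool's
        // configuration. If that connection drops, sqlx re-establishes it and
        // re-issues LISTEN for the channels registered below, though any
        // notifications emitted while disconnected are lost (a best-effort
        // note on sqlx behavior, not something this patch relies on).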
pglistener.listen("cost_models_update_notify").await.expect( @@ -64,12 +67,14 @@ impl MinimumValue { fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { // get agora model for the allocation_id let cache = self.cost_model_cache.read().unwrap(); - // on average, we'll have zero or one model let models = cache.get(&agora_query.deployment_id); let expected_value = models - .map(|cache| cache.lock().unwrap().cost(agora_query)) - .unwrap_or_default(); + .map(|cache| { + let cache = cache.read().unwrap(); + cache.cost(agora_query) + }) + .unwrap_or(MINIMAL_VALUE); Ok(expected_value) } @@ -77,15 +82,33 @@ impl MinimumValue { async fn cost_models_watcher( pgpool: PgPool, mut pglistener: PgListener, - cost_model_cache: Arc>>>, + cost_model_cache: CostModelMap, cancel_token: tokio_util::sync::CancellationToken, ) { + let handles: Arc>>>> = + Default::default(); + let (tx, mut rx) = channel::(64); + loop { tokio::select! { _ = cancel_token.cancelled() => { break; } + Some(deployment_id) = rx.recv() => { + let mut cost_model_write = cost_model_cache.write().unwrap(); + if let Some(cache) = cost_model_write.get_mut(&deployment_id) { + cache.get_mut().unwrap().expire(); + } + if let Entry::Occupied(mut entry) = handles.lock().unwrap().entry(deployment_id) { + let vec = entry.get_mut(); + vec.pop_front(); + if vec.is_empty() { + entry.remove(); + } + } + + } pg_notification = pglistener.recv() => { let pg_notification = pg_notification.expect( "should be able to receive Postgres Notify events on the channel \ @@ -103,31 +126,38 @@ impl MinimumValue { match cost_model_notification.tg_op.as_str() { "INSERT" => { let cost_model_source: CostModelSource = cost_model_notification.into(); - let mut cost_model_cache = cost_model_cache - .write() - .unwrap(); - - match cost_model_cache.get_mut(&deployment_id) { - Some(cache) => { - let _ = cache.lock().unwrap().insert_model(cost_model_source); - }, - None => { - if let Ok(cache) = CostModelCache::new(cost_model_source).inspect_err(|err| { - tracing::error!( - "Error while compiling cost model for deployment id {}. Error: {}", - deployment_id, err - ) - }) { - cost_model_cache.insert(deployment_id, Mutex::new(cache)); - } - }, + { + let mut cost_model_write = cost_model_cache + .write() + .unwrap(); + let cache = cost_model_write.entry(deployment_id).or_default(); + let _ = cache.get_mut().unwrap().insert_model(cost_model_source); } + let _tx = tx.clone(); + + // expire after 60 seconds + handles.lock() + .unwrap() + .entry(deployment_id) + .or_default() + .push_back(tokio::spawn(async move { + // 1 minute after, we expire the older cache + sleep(Duration::from_secs(60)).await; + let _ = _tx.send(deployment_id).await; + })); } "DELETE" => { - cost_model_cache - .write() - .unwrap() - .remove(&cost_model_notification.deployment); + if let Entry::Occupied(mut entry) = cost_model_cache + .write().unwrap().entry(cost_model_notification.deployment) { + let should_remove = { + let mut cost_model = entry.get_mut().write().unwrap(); + cost_model.expire(); + cost_model.is_empty() + }; + if should_remove { + entry.remove(); + } + } } // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost // model cache. 
@@ -138,6 +168,17 @@ impl MinimumValue { cost_model_notification.tg_op ); + { + // clear all pending expire + let mut handles = handles.lock().unwrap(); + for maps in handles.values() { + for handle in maps { + handle.abort(); + } + } + handles.clear(); + } + Self::value_check_reload(&pgpool, cost_model_cache.clone()) .await .expect("should be able to reload cost models") @@ -150,7 +191,7 @@ impl MinimumValue { async fn value_check_reload( pgpool: &PgPool, - cost_model_cache: Arc>>>, + cost_model_cache: CostModelMap, ) -> anyhow::Result<()> { let models = sqlx::query!( r#" @@ -166,13 +207,14 @@ impl MinimumValue { .into_iter() .map(|record| { let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?; - let model = CostModelCache::new(CostModelSource { + let mut model = CostModelCache::default(); + let _ = model.insert_model(CostModelSource { deployment_id, model: record.model.unwrap(), - variables: record.variables.unwrap().to_string(), - })?; + variables: record.variables.unwrap_or_default(), + }); - Ok::<_, anyhow::Error>((deployment_id, Mutex::new(model))) + Ok::<_, ParseDeploymentIdError>((deployment_id, RwLock::new(model))) }) .collect::, _>>()?; @@ -220,7 +262,7 @@ fn compile_cost_model(src: CostModelSource) -> anyhow::Result { if src.model.len() > (1 << 16) { return Err(anyhow!("CostModelTooLarge")); } - let model = CostModel::compile(&src.model, &src.variables)?; + let model = CostModel::compile(&src.model, &src.variables.to_string())?; Ok(model) } @@ -231,10 +273,10 @@ pub struct AgoraQuery { } #[derive(Clone, Eq, Hash, PartialEq)] -pub struct CostModelSource { +struct CostModelSource { pub deployment_id: DeploymentId, pub model: String, - pub variables: String, + pub variables: serde_json::Value, } #[derive(serde::Deserialize)] @@ -242,7 +284,7 @@ struct CostModelNotification { tg_op: String, deployment: DeploymentId, model: String, - variables: String, + variables: serde_json::Value, } impl From for CostModelSource { @@ -255,48 +297,29 @@ impl From for CostModelSource { } } -pub struct CostModelCache { - models: TtlCache, - latest_model: CostModel, - latest_source: CostModelSource, +#[derive(Default)] +struct CostModelCache { + models: VecDeque, } impl CostModelCache { - pub fn new(source: CostModelSource) -> anyhow::Result { - let model = compile_cost_model(source.clone())?; - Ok(Self { - latest_model: model, - latest_source: source, - // arbitrary number of models copy - models: TtlCache::new(10), - }) - } - fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> { - if source != self.latest_source { - let model = compile_cost_model(source.clone())?; - // update latest and insert into ttl the old model - let old_model = std::mem::replace(&mut self.latest_model, model); - self.latest_source = source.clone(); - - self.models - // arbitrary cache duration - .insert(source, old_model, Duration::from_secs(60)); - } + let model = compile_cost_model(source.clone())?; + self.models.push_back(model); Ok(()) } - fn get_models(&mut self) -> Vec<&CostModel> { - let mut values: Vec<&CostModel> = self.models.iter().map(|(_, v)| v).collect(); - values.push(&self.latest_model); - values + fn expire(&mut self) { + self.models.pop_front(); } - fn cost(&mut self, query: &AgoraQuery) -> u128 { - let models = self.get_models(); + fn is_empty(&self) -> bool { + self.models.is_empty() + } - models - .into_iter() + fn cost(&self, query: &AgoraQuery) -> u128 { + self.models + .iter() .fold(None, |acc, model| { let value = model .cost(&query.query, &query.variables) @@ 
-310,7 +333,7 @@ impl CostModelCache {
                     Some(value)
                 }
             })
-            .unwrap_or_default()
+            .unwrap_or(MINIMAL_VALUE)
     }
 }
 
@@ -319,17 +342,17 @@ mod tests {
     use sqlx::PgPool;
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn initialize_check(pg_pool: PgPool) {}
+    async fn initialize_check(_pg_pool: PgPool) {}
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn should_initialize_check_with_caches(pg_pool: PgPool) {}
+    async fn should_initialize_check_with_caches(_pg_pool: PgPool) {}
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn should_add_model_to_cache_on_insert(pg_pool: PgPool) {}
+    async fn should_add_model_to_cache_on_insert(_pg_pool: PgPool) {}
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn should_expire_old_model(pg_pool: PgPool) {}
+    async fn should_expire_old_model(_pg_pool: PgPool) {}
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn should_verify_global_model(pg_pool: PgPool) {}
+    async fn should_verify_global_model(_pg_pool: PgPool) {}
 }

From 51c6207d9fac10a0900592551243f5e9b039e89b Mon Sep 17 00:00:00 2001
From: Gustavo Inacio
Date: Fri, 11 Oct 2024 22:47:49 +0200
Subject: [PATCH 17/21] chore: add some todo comments

Signed-off-by: Gustavo Inacio
---
 common/src/tap/checks/value_check.rs | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index 65c3e2e43..1c07e444a 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -69,6 +69,8 @@ impl MinimumValue {
         let cache = self.cost_model_cache.read().unwrap();
         let models = cache.get(&agora_query.deployment_id);
 
+        // TODO check global cost model
+
         let expected_value = models
             .map(|cache| {
@@ -193,6 +195,9 @@ impl MinimumValue {
         pgpool: &PgPool,
         cost_model_cache: CostModelMap,
     ) -> anyhow::Result<()> {
+        // TODO make sure to load the last cost model
+        // plus all models that were created in the last 60 seconds
+        //
         let models = sqlx::query!(
             r#"
             SELECT deployment, model, variables

From 487f013be8dbe9c09d428350971e058d0b5f6d9f Mon Sep 17 00:00:00 2001
From: Gustavo Inacio
Date: Fri, 18 Oct 2024 20:54:02 +0200
Subject: [PATCH 18/21] refactor: move cost models to common

Signed-off-by: Gustavo Inacio
---
 common/src/cost_model.rs                     | 465 ++++++++++++++++++
 common/src/lib.rs                            |   1 +
 common/src/tap/checks/value_check.rs         | 326 ++++++++-----
 migrations/20230901142040_cost_models.up.sql |  23 +-
 service/src/database.rs                      | 469 +------------------
 service/src/routes/cost.rs                   |   6 +-
 6 files changed, 675 insertions(+), 615 deletions(-)
 create mode 100644 common/src/cost_model.rs

diff --git a/common/src/cost_model.rs b/common/src/cost_model.rs
new file mode 100644
index 000000000..2924a1c0f
--- /dev/null
+++ b/common/src/cost_model.rs
@@ -0,0 +1,465 @@
+// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
+// SPDX-License-Identifier: Apache-2.0
+
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use sqlx::PgPool;
+use std::{collections::HashSet, str::FromStr};
+use thegraph_core::{DeploymentId, ParseDeploymentIdError};
+
+/// Internal cost model representation as stored in the database.
+///
+/// These can have "global" as the deployment ID.
+#[derive(Debug, Clone)]
+pub(crate) struct DbCostModel {
+    pub deployment: Option<String>,
+    pub model: Option<String>,
+    pub variables: Option<Value>,
+}
+
+/// External representation of cost models.
+///
+/// Here, any notion of "global" is gone and deployment IDs are valid deployment IDs.
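+/// A row whose deployment column holds the literal string "global" acts as a
+/// fallback: `cost_models` and `cost_model` below merge its model and
+/// variables into every deployment that lacks its own definition.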
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostModel { + pub deployment: DeploymentId, + pub model: Option, + pub variables: Option, +} + +impl TryFrom for CostModel { + type Error = ParseDeploymentIdError; + + fn try_from(db_model: DbCostModel) -> Result { + Ok(Self { + deployment: DeploymentId::from_str(&db_model.deployment.unwrap())?, + model: db_model.model, + variables: db_model.variables, + }) + } +} + +impl From for DbCostModel { + fn from(model: CostModel) -> Self { + let deployment = model.deployment; + DbCostModel { + deployment: Some(format!("{deployment:#x}")), + model: model.model, + variables: model.variables, + } + } +} + +/// Query cost models from the database, merging the global cost model in +/// whenever there is no cost model defined for a deployment. +pub async fn cost_models( + pool: &PgPool, + deployments: &[DeploymentId], +) -> Result, anyhow::Error> { + let hex_ids = deployments + .iter() + .map(|d| format!("{d:#x}")) + .collect::>(); + + let mut models = if deployments.is_empty() { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment != 'global' + ORDER BY deployment ASC + "# + ) + .fetch_all(pool) + .await? + } else { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = ANY($1) + AND deployment != 'global' + ORDER BY deployment ASC + "#, + &hex_ids + ) + .fetch_all(pool) + .await? + } + .into_iter() + .map(CostModel::try_from) + .collect::, _>>()?; + + let deployments_with_models = models + .iter() + .map(|model| &model.deployment) + .collect::>(); + + let deployments_without_models = deployments + .iter() + .filter(|deployment| !deployments_with_models.contains(deployment)) + .collect::>(); + + // Query the global cost model + if let Some(global_model) = global_cost_model(pool).await? { + // For all deployments that have a cost model, merge it with the global one + models = models + .into_iter() + .map(|model| merge_global(model, &global_model)) + // Inject a cost model for all deployments that don't have one + .chain( + deployments_without_models + .into_iter() + .map(|deployment| CostModel { + deployment: deployment.to_owned(), + model: global_model.model.clone(), + variables: global_model.variables.clone(), + }), + ) + .collect(); + } + + Ok(models) +} + +/// Make database query for a cost model indexed by deployment id +pub async fn cost_model( + pool: &PgPool, + deployment: &DeploymentId, +) -> Result, anyhow::Error> { + let model = sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = $1 + AND deployment != 'global' + "#, + format!("{:#x}", deployment), + ) + .fetch_optional(pool) + .await? 
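+            // fetch_optional yields at most one DbCostModel row;
+            // map(CostModel::try_from) parses its deployment id, and
+            // transpose turns Option<Result<_>> into Result<Option<_>>.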
+ .map(CostModel::try_from) + .transpose()?; + + let global_model = global_cost_model(pool).await?; + + Ok(match (model, global_model) { + // If we have no global model, return whatever we can find for the deployment + (None, None) => None, + (Some(model), None) => Some(model), + + // If we have a cost model and a global cost model, merge them + (Some(model), Some(global_model)) => Some(merge_global(model, &global_model)), + + // If we have only a global model, use that + (None, Some(global_model)) => Some(CostModel { + deployment: deployment.to_owned(), + model: global_model.model, + variables: global_model.variables, + }), + }) +} + +/// Query global cost model +async fn global_cost_model(pool: &PgPool) -> Result, anyhow::Error> { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = $1 + "#, + "global" + ) + .fetch_optional(pool) + .await + .map_err(Into::into) +} + +fn merge_global(model: CostModel, global_model: &DbCostModel) -> CostModel { + CostModel { + deployment: model.deployment, + model: model.model.clone().or(global_model.model.clone()), + variables: model.variables.clone().or(global_model.variables.clone()), + } +} + +#[cfg(test)] +pub(crate) mod test { + + use std::str::FromStr; + + use sqlx::PgPool; + + use super::*; + + pub async fn add_cost_models(pool: &PgPool, models: Vec) { + for model in models { + sqlx::query!( + r#" + INSERT INTO "CostModelsHistory" (deployment, model) + VALUES ($1, $2); + "#, + model.deployment, + model.model, + ) + .execute(pool) + .await + .expect("Create test instance in db"); + } + } + + pub fn to_db_models(models: Vec) -> Vec { + models.into_iter().map(DbCostModel::from).collect() + } + + fn global_cost_model() -> DbCostModel { + DbCostModel { + deployment: Some("global".to_string()), + model: Some("default => 0.00001;".to_string()), + variables: None, + } + } + + pub fn test_data() -> Vec { + vec![ + CostModel { + deployment: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + .parse() + .unwrap(), + model: None, + variables: None, + }, + CostModel { + deployment: "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea" + .parse() + .unwrap(), + model: Some("default => 0.00025;".to_string()), + variables: None, + }, + CostModel { + deployment: "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc" + .parse() + .unwrap(), + model: Some("default => 0.00012;".to_string()), + variables: None, + }, + ] + } + + #[sqlx::test(migrations = "../migrations")] + async fn success_cost_models(pool: PgPool) { + let test_models = test_data(); + let test_deployments = test_models + .iter() + .map(|model| model.deployment) + .collect::>(); + + add_cost_models(&pool, to_db_models(test_models.clone())).await; + + // First test: query without deployment filter + let models = cost_models(&pool, &[]) + .await + .expect("cost models query without deployment filter"); + + // We expect as many models as we have in the test data + assert_eq!(models.len(), test_models.len()); + + // We expect models for all test deployments to be present and + // identical to the test data + for test_deployment in test_deployments.iter() { + let test_model = test_models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in test data"); + + let model = models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in query result"); + + 
+            assert_eq!(test_model.model, model.model);
+        }
+
+        // Second test: query with a deployment filter
+        let sample_deployments = vec![
+            test_models.first().unwrap().deployment,
+            test_models.get(1).unwrap().deployment,
+        ];
+        let models = cost_models(&pool, &sample_deployments)
+            .await
+            .expect("cost models query with deployment filter");
+
+        // Expect two cost models to be returned
+        assert_eq!(models.len(), sample_deployments.len());
+
+        // Expect both returned deployments to be identical to the test data
+        for test_deployment in sample_deployments.iter() {
+            let test_model = test_models
+                .iter()
+                .find(|model| &model.deployment == test_deployment)
+                .expect("finding cost model for test deployment in test data");
+
+            let model = models
+                .iter()
+                .find(|model| &model.deployment == test_deployment)
+                .expect("finding cost model for test deployment in query result");
+
+            assert_eq!(test_model.model, model.model);
+        }
+    }
+
+    #[sqlx::test(migrations = "../migrations")]
+    async fn global_fallback_cost_models(pool: PgPool) {
+        let test_models = test_data();
+        let test_deployments = test_models
+            .iter()
+            .map(|model| model.deployment)
+            .collect::<Vec<_>>();
+        let global_model = global_cost_model();
+
+        add_cost_models(&pool, to_db_models(test_models.clone())).await;
+        add_cost_models(&pool, vec![global_model.clone()]).await;
+
+        // First test: fetch cost models without filtering by deployment
+        let models = cost_models(&pool, &[])
+            .await
+            .expect("cost models query without deployments filter");
+
+        // Since we've defined 3 cost models and we did not provide a filter, we
+        // expect all of them to be returned except for the global cost model
+        assert_eq!(models.len(), test_models.len());
+
+        // Expect all test deployments to be present in the query result
+        for test_deployment in test_deployments.iter() {
+            let test_model = test_models
+                .iter()
+                .find(|model| &model.deployment == test_deployment)
+                .expect("finding cost model for deployment in test data");
+
+            let model = models
+                .iter()
+                .find(|model| &model.deployment == test_deployment)
+                .expect("finding cost model for deployment in query result");
+
+            if test_model.model.is_some() {
+                // If the test model has a model definition, we expect that to be returned
+                assert_eq!(model.model, test_model.model);
+            } else {
+                // If the test model has no model definition, we expect the global
+                // model definition to be returned
+                assert_eq!(model.model, global_model.model);
+            }
+        }
+
+        // Second test: fetch cost models, filtering by the first two deployment IDs
+        let sample_deployments = vec![
+            test_models.first().unwrap().deployment,
+            test_models.get(1).unwrap().deployment,
+        ];
+        let models = dbg!(cost_models(&pool, &sample_deployments).await)
+            .expect("cost models query with deployments filter");
+
+        // We've filtered by two deployment IDs and are expecting two cost models to be returned
+        assert_eq!(models.len(), sample_deployments.len());
+
+        for test_deployment in sample_deployments {
+            let test_model = test_models
+                .iter()
+                .find(|model| model.deployment == test_deployment)
+                .expect("finding cost model for deployment in test data");
+
+            let model = models
+                .iter()
+                .find(|model| model.deployment == test_deployment)
+                .expect("finding cost model for deployment in query result");
+
+            if test_model.model.is_some() {
+                // If the test model has a model definition, we expect that to be returned
+                assert_eq!(model.model, test_model.model);
+            } else {
+                // If the test model has no model definition, we expect the global
+                // model
definition to be returned + assert_eq!(model.model, global_model.model); + } + } + + // Third test: query for missing cost model + let missing_deployment = + DeploymentId::from_str("Qmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").unwrap(); + let models = cost_models(&pool, &[missing_deployment]) + .await + .expect("cost models query for missing deployment"); + + // The deployment may be missing in the database but we have a global model + // and expect that to be returned, with the missing deployment ID + let missing_model = models + .iter() + .find(|m| m.deployment == missing_deployment) + .expect("finding missing deployment"); + assert_eq!(missing_model.model, global_model.model); + } + + #[sqlx::test(migrations = "../migrations")] + async fn success_cost_model(pool: PgPool) { + add_cost_models(&pool, to_db_models(test_data())).await; + + let deployment_id_from_bytes = DeploymentId::from_str( + "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea", + ) + .unwrap(); + let deployment_id_from_hash = + DeploymentId::from_str("Qmb5Ysp5oCUXhLA8NmxmYKDAX2nCMnh7Vvb5uffb9n5vss").unwrap(); + + assert_eq!(deployment_id_from_bytes, deployment_id_from_hash); + + let model = cost_model(&pool, &deployment_id_from_bytes) + .await + .expect("cost model query") + .expect("cost model for deployment"); + + assert_eq!(model.deployment, deployment_id_from_hash); + assert_eq!(model.model, Some("default => 0.00025;".to_string())); + } + + #[sqlx::test(migrations = "../migrations")] + async fn global_fallback_cost_model(pool: PgPool) { + let test_models = test_data(); + let global_model = global_cost_model(); + + add_cost_models(&pool, to_db_models(test_models.clone())).await; + add_cost_models(&pool, vec![global_model.clone()]).await; + + // Test that the behavior is correct for existing deployments + for test_model in test_models { + let model = cost_model(&pool, &test_model.deployment) + .await + .expect("cost model query") + .expect("global cost model fallback"); + + assert_eq!(model.deployment, test_model.deployment); + + if test_model.model.is_some() { + // If the test model has a model definition, we expect that to be returned + assert_eq!(model.model, test_model.model); + } else { + // If the test model has no model definition, we expect the global + // model definition to be returned + assert_eq!(model.model, global_model.model); + } + } + + // Test that querying a non-existing deployment returns the default cost model + let missing_deployment = + DeploymentId::from_str("Qmnononononononononononononononononononononono").unwrap(); + let model = cost_model(&pool, &missing_deployment) + .await + .expect("cost model query") + .expect("global cost model fallback"); + assert_eq!(model.deployment, missing_deployment); + assert_eq!(model.model, global_model.model); + } +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 63470d754..7b20b85c2 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -4,6 +4,7 @@ pub mod address; pub mod allocations; pub mod attestations; +pub mod cost_model; pub mod escrow_accounts; pub mod graphql; pub mod indexer_service; diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs index 1c07e444a..ec9abe714 100644 --- a/common/src/tap/checks/value_check.rs +++ b/common/src/tap/checks/value_check.rs @@ -4,7 +4,10 @@ use anyhow::anyhow; use bigdecimal::ToPrimitive; use cost_model::CostModel; -use sqlx::{postgres::PgListener, PgPool}; +use sqlx::{ + postgres::{PgListener, PgNotification}, + PgPool, +}; use std::{ cmp::min, 
collections::{hash_map::Entry, HashMap, VecDeque}, @@ -13,8 +16,12 @@ use std::{ time::Duration, }; use thegraph_core::{DeploymentId, ParseDeploymentIdError}; -use tokio::{sync::mpsc::channel, task::JoinHandle, time::sleep}; -use tracing::error; +use tokio::{ + sync::mpsc::{channel, Sender}, + task::JoinHandle, + time::sleep, +}; +use tracing::{debug, error}; use tap_core::receipt::{ checks::{Check, CheckError, CheckResult}, @@ -29,12 +36,159 @@ type CostModelMap = Arc>>>; pub struct MinimumValue { cost_model_cache: CostModelMap, + global_model_cache: Arc>, watcher_cancel_token: tokio_util::sync::CancellationToken, } +struct CostModelWatcher { + pgpool: PgPool, + tx: Sender, + + cost_model_cache: CostModelMap, + global_model_cache: Arc>, + + handles: Arc>>>>, +} + +impl CostModelWatcher { + async fn cost_models_watcher( + pgpool: PgPool, + mut pglistener: PgListener, + cost_model_cache: CostModelMap, + global_model_cache: Arc>, + cancel_token: tokio_util::sync::CancellationToken, + ) { + let handles: Arc>>>> = + Default::default(); + let (tx, mut rx) = channel::(64); + let cost_model_watcher = CostModelWatcher { + pgpool, + tx, + handles, + global_model_cache, + cost_model_cache, + }; + + loop { + tokio::select! { + _ = cancel_token.cancelled() => { + break; + } + Some(deployment_id) = rx.recv() => { + cost_model_watcher.cancel_cache_expire(deployment_id).await; + } + pg_notification = pglistener.recv() => { + let pg_notification = pg_notification.expect( + "should be able to receive Postgres Notify events on the channel \ + 'cost_models_update_notify'", + ); + cost_model_watcher.new_notification( + pg_notification, + ).await; + + } + } + } + } + + async fn new_notification(&self, pg_notification: PgNotification) { + let cost_model_notification: CostModelNotification = + serde_json::from_str(pg_notification.payload()).expect( + "should be able to deserialize the Postgres Notify event payload as a \ + CostModelNotification", + ); + + let deployment_id = match cost_model_notification.deployment.as_str() { + "global" => { + debug!("Received an update for 'global' cost model"); + return; + } + deployment_id => DeploymentId::from_str(deployment_id).unwrap(), + }; + + match cost_model_notification.tg_op.as_str() { + "INSERT" => { + let cost_model_source: CostModelSource = + cost_model_notification.try_into().unwrap(); + { + let mut cost_model_write = self.cost_model_cache.write().unwrap(); + let cache = cost_model_write.entry(deployment_id).or_default(); + let _ = cache.get_mut().unwrap().insert_model(cost_model_source); + } + let _tx = self.tx.clone(); + + // expire after 60 seconds + self.handles + .lock() + .unwrap() + .entry(deployment_id) + .or_default() + .push_back(tokio::spawn(async move { + // 1 minute after, we expire the older cache + sleep(Duration::from_secs(60)).await; + let _ = _tx.send(deployment_id).await; + })); + } + "DELETE" => { + if let Entry::Occupied(mut entry) = + self.cost_model_cache.write().unwrap().entry(deployment_id) + { + let should_remove = { + let mut cost_model = entry.get_mut().write().unwrap(); + cost_model.expire(); + cost_model.is_empty() + }; + if should_remove { + entry.remove(); + } + } + } + // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost + // model cache. + _ => { + error!( + "Received an unexpected cost model table notification: {}. 
Reloading entire \ + cost model.", + cost_model_notification.tg_op + ); + + { + // clear all pending expire + let mut handles = self.handles.lock().unwrap(); + for maps in handles.values() { + for handle in maps { + handle.abort(); + } + } + handles.clear(); + } + + MinimumValue::value_check_reload(&self.pgpool, self.cost_model_cache.clone()) + .await + .expect("should be able to reload cost models") + } + } + } + + async fn cancel_cache_expire(&self, deployment_id: DeploymentId) { + let mut cost_model_write = self.cost_model_cache.write().unwrap(); + if let Some(cache) = cost_model_write.get_mut(&deployment_id) { + cache.get_mut().unwrap().expire(); + } + + if let Entry::Occupied(mut entry) = self.handles.lock().unwrap().entry(deployment_id) { + let vec = entry.get_mut(); + vec.pop_front(); + if vec.is_empty() { + entry.remove(); + } + } + } +} + impl Drop for MinimumValue { fn drop(&mut self) { - // Clean shutdown for the sender_denylist_watcher + // Clean shutdown for the minimum value check // Though since it's not a critical task, we don't wait for it to finish (join). self.watcher_cancel_token.cancel(); } @@ -43,6 +197,9 @@ impl Drop for MinimumValue { impl MinimumValue { pub async fn new(pgpool: PgPool) -> Self { let cost_model_cache: CostModelMap = Default::default(); + Self::value_check_reload(&pgpool, cost_model_cache.clone()) + .await + .expect("should be able to reload cost models"); let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); pglistener.listen("cost_models_update_notify").await.expect( @@ -50,15 +207,19 @@ impl MinimumValue { 'cost_models_update_notify'", ); + let global_model_cache: Arc> = Default::default(); + let watcher_cancel_token = tokio_util::sync::CancellationToken::new(); - tokio::spawn(Self::cost_models_watcher( + tokio::spawn(CostModelWatcher::cost_models_watcher( pgpool.clone(), pglistener, cost_model_cache.clone(), + global_model_cache.clone(), watcher_cancel_token.clone(), )); Self { + global_model_cache, cost_model_cache, watcher_cancel_token, } @@ -81,116 +242,6 @@ impl MinimumValue { Ok(expected_value) } - async fn cost_models_watcher( - pgpool: PgPool, - mut pglistener: PgListener, - cost_model_cache: CostModelMap, - cancel_token: tokio_util::sync::CancellationToken, - ) { - let handles: Arc>>>> = - Default::default(); - let (tx, mut rx) = channel::(64); - - loop { - tokio::select! 
{ - _ = cancel_token.cancelled() => { - break; - } - Some(deployment_id) = rx.recv() => { - let mut cost_model_write = cost_model_cache.write().unwrap(); - if let Some(cache) = cost_model_write.get_mut(&deployment_id) { - cache.get_mut().unwrap().expire(); - } - - if let Entry::Occupied(mut entry) = handles.lock().unwrap().entry(deployment_id) { - let vec = entry.get_mut(); - vec.pop_front(); - if vec.is_empty() { - entry.remove(); - } - } - - } - pg_notification = pglistener.recv() => { - let pg_notification = pg_notification.expect( - "should be able to receive Postgres Notify events on the channel \ - 'cost_models_update_notify'", - ); - - let cost_model_notification: CostModelNotification = - serde_json::from_str(pg_notification.payload()).expect( - "should be able to deserialize the Postgres Notify event payload as a \ - CostModelNotification", - ); - - let deployment_id = cost_model_notification.deployment; - - match cost_model_notification.tg_op.as_str() { - "INSERT" => { - let cost_model_source: CostModelSource = cost_model_notification.into(); - { - let mut cost_model_write = cost_model_cache - .write() - .unwrap(); - let cache = cost_model_write.entry(deployment_id).or_default(); - let _ = cache.get_mut().unwrap().insert_model(cost_model_source); - } - let _tx = tx.clone(); - - // expire after 60 seconds - handles.lock() - .unwrap() - .entry(deployment_id) - .or_default() - .push_back(tokio::spawn(async move { - // 1 minute after, we expire the older cache - sleep(Duration::from_secs(60)).await; - let _ = _tx.send(deployment_id).await; - })); - } - "DELETE" => { - if let Entry::Occupied(mut entry) = cost_model_cache - .write().unwrap().entry(cost_model_notification.deployment) { - let should_remove = { - let mut cost_model = entry.get_mut().write().unwrap(); - cost_model.expire(); - cost_model.is_empty() - }; - if should_remove { - entry.remove(); - } - } - } - // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost - // model cache. - _ => { - error!( - "Received an unexpected cost model table notification: {}. 
Reloading entire \ - cost model.", - cost_model_notification.tg_op - ); - - { - // clear all pending expire - let mut handles = handles.lock().unwrap(); - for maps in handles.values() { - for handle in maps { - handle.abort(); - } - } - handles.clear(); - } - - Self::value_check_reload(&pgpool, cost_model_cache.clone()) - .await - .expect("should be able to reload cost models") - } - } - } - } - } - } - async fn value_check_reload( pgpool: &PgPool, cost_model_cache: CostModelMap, @@ -214,8 +265,7 @@ impl MinimumValue { let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?; let mut model = CostModelCache::default(); let _ = model.insert_model(CostModelSource { - deployment_id, - model: record.model.unwrap(), + model: record.model.unwrap_or_default(), variables: record.variables.unwrap_or_default(), }); @@ -263,11 +313,11 @@ impl Check for MinimumValue { } } -fn compile_cost_model(src: CostModelSource) -> anyhow::Result { - if src.model.len() > (1 << 16) { +fn compile_cost_model(model: String, variables: String) -> anyhow::Result { + if model.len() > (1 << 16) { return Err(anyhow!("CostModelTooLarge")); } - let model = CostModel::compile(&src.model, &src.variables.to_string())?; + let model = CostModel::compile(&model, &variables)?; Ok(model) } @@ -279,7 +329,6 @@ pub struct AgoraQuery { #[derive(Clone, Eq, Hash, PartialEq)] struct CostModelSource { - pub deployment_id: DeploymentId, pub model: String, pub variables: serde_json::Value, } @@ -287,17 +336,16 @@ struct CostModelSource { #[derive(serde::Deserialize)] struct CostModelNotification { tg_op: String, - deployment: DeploymentId, - model: String, - variables: serde_json::Value, + deployment: String, + model: Option, + variables: Option, } impl From for CostModelSource { fn from(value: CostModelNotification) -> Self { CostModelSource { - deployment_id: value.deployment, - model: value.model, - variables: value.variables, + model: value.model.unwrap_or_default(), + variables: value.variables.unwrap_or_default(), } } } @@ -309,7 +357,7 @@ struct CostModelCache { impl CostModelCache { fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> { - let model = compile_cost_model(source.clone())?; + let model = compile_cost_model(source.model, source.variables.to_string())?; self.models.push_back(model); Ok(()) } @@ -346,11 +394,29 @@ impl CostModelCache { mod tests { use sqlx::PgPool; + use crate::cost_model::test::{add_cost_models, to_db_models}; + + use super::MinimumValue; + #[sqlx::test(migrations = "../migrations")] - async fn initialize_check(_pg_pool: PgPool) {} + async fn initialize_check(pgpool: PgPool) { + let check = MinimumValue::new(pgpool).await; + assert_eq!(check.cost_model_cache.read().unwrap().len(), 0); + } #[sqlx::test(migrations = "../migrations")] - async fn should_initialize_check_with_caches(_pg_pool: PgPool) {} + async fn should_initialize_check_with_caches(pgpool: PgPool) { + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool).await; + assert_eq!( + check.cost_model_cache.read().unwrap().len(), + test_models.len() + ); + } #[sqlx::test(migrations = "../migrations")] async fn should_add_model_to_cache_on_insert(_pg_pool: PgPool) {} diff --git a/migrations/20230901142040_cost_models.up.sql b/migrations/20230901142040_cost_models.up.sql index a5e1862fc..1e3dc7ad7 100644 --- 
a/migrations/20230901142040_cost_models.up.sql +++ b/migrations/20230901142040_cost_models.up.sql @@ -8,20 +8,15 @@ CREATE TABLE IF NOT EXISTS "CostModelsHistory" "updatedAt" TIMESTAMP WITH TIME ZONE ); -CREATE VIEW "CostModels" AS SELECT id, - deployment, - model, - variables, - "createdAt", - "updatedAt" - FROM "CostModelsHistory" t1 - JOIN - ( - SELECT MAX(id) - FROM "CostModelsHistory" - GROUP BY deployment - ) t2 - ON t1.id = t2.MAX; +CREATE VIEW "CostModels" AS +SELECT DISTINCT ON (deployment, model, variables) + deployment, + model, + variables, + "createdAt", + "updatedAt" +FROM "CostModelsHistory" +ORDER BY deployment, model, variables, id DESC; CREATE FUNCTION cost_models_update_notify() RETURNS trigger AS diff --git a/service/src/database.rs b/service/src/database.rs index 31708b3c3..7054f0358 100644 --- a/service/src/database.rs +++ b/service/src/database.rs @@ -1,13 +1,8 @@ // Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. // SPDX-License-Identifier: Apache-2.0 -use std::time::Duration; -use std::{collections::HashSet, str::FromStr}; - -use serde::{Deserialize, Serialize}; -use serde_json::Value; use sqlx::{postgres::PgPoolOptions, PgPool}; -use thegraph_core::{DeploymentId, ParseDeploymentIdError}; +use std::time::Duration; use tracing::debug; pub async fn connect(url: &str) -> PgPool { @@ -20,465 +15,3 @@ pub async fn connect(url: &str) -> PgPool { .await .expect("Should be able to connect to the database") } - -/// Internal cost model representation as stored in the database. -/// -/// These can have "global" as the deployment ID. -#[derive(Debug, Clone)] -struct DbCostModel { - pub deployment: Option, - pub model: Option, - pub variables: Option, -} - -/// External representation of cost models. -/// -/// Here, any notion of "global" is gone and deployment IDs are valid deployment IDs. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CostModel { - pub deployment: DeploymentId, - pub model: Option, - pub variables: Option, -} - -impl TryFrom for CostModel { - type Error = ParseDeploymentIdError; - - fn try_from(db_model: DbCostModel) -> Result { - Ok(Self { - deployment: DeploymentId::from_str(&db_model.deployment.ok_or( - ParseDeploymentIdError::InvalidIpfsHashLength { - value: String::new(), - length: 0, - }, - )?)?, - model: db_model.model, - variables: db_model.variables, - }) - } -} - -impl From for DbCostModel { - fn from(model: CostModel) -> Self { - let deployment = model.deployment; - DbCostModel { - deployment: Some(format!("{deployment:#x}")), - model: model.model, - variables: model.variables, - } - } -} - -/// Query cost models from the database, merging the global cost model in -/// whenever there is no cost model defined for a deployment. -pub async fn cost_models( - pool: &PgPool, - deployments: &[DeploymentId], -) -> Result, anyhow::Error> { - let hex_ids = deployments - .iter() - .map(|d| format!("{d:#x}")) - .collect::>(); - - let mut models = if deployments.is_empty() { - sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment != 'global' - ORDER BY deployment ASC - "# - ) - .fetch_all(pool) - .await? - } else { - sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment = ANY($1) - AND deployment != 'global' - ORDER BY deployment ASC - "#, - &hex_ids - ) - .fetch_all(pool) - .await? 
- } - .into_iter() - .map(CostModel::try_from) - .collect::, _>>()?; - - let deployments_with_models = models - .iter() - .map(|model| &model.deployment) - .collect::>(); - - let deployments_without_models = deployments - .iter() - .filter(|deployment| !deployments_with_models.contains(deployment)) - .collect::>(); - - // Query the global cost model - if let Some(global_model) = global_cost_model(pool).await? { - // For all deployments that have a cost model, merge it with the global one - models = models - .into_iter() - .map(|model| merge_global(model, &global_model)) - // Inject a cost model for all deployments that don't have one - .chain( - deployments_without_models - .into_iter() - .map(|deployment| CostModel { - deployment: deployment.to_owned(), - model: global_model.model.clone(), - variables: global_model.variables.clone(), - }), - ) - .collect(); - } - - Ok(models) -} - -/// Make database query for a cost model indexed by deployment id -pub async fn cost_model( - pool: &PgPool, - deployment: &DeploymentId, -) -> Result, anyhow::Error> { - let model = sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment = $1 - AND deployment != 'global' - "#, - format!("{:#x}", deployment), - ) - .fetch_optional(pool) - .await? - .map(CostModel::try_from) - .transpose()?; - - let global_model = global_cost_model(pool).await?; - - Ok(match (model, global_model) { - // If we have no global model, return whatever we can find for the deployment - (None, None) => None, - (Some(model), None) => Some(model), - - // If we have a cost model and a global cost model, merge them - (Some(model), Some(global_model)) => Some(merge_global(model, &global_model)), - - // If we have only a global model, use that - (None, Some(global_model)) => Some(CostModel { - deployment: deployment.to_owned(), - model: global_model.model, - variables: global_model.variables, - }), - }) -} - -/// Query global cost model -async fn global_cost_model(pool: &PgPool) -> Result, anyhow::Error> { - sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment = $1 - "#, - "global" - ) - .fetch_optional(pool) - .await - .map_err(Into::into) -} - -fn merge_global(model: CostModel, global_model: &DbCostModel) -> CostModel { - CostModel { - deployment: model.deployment, - model: model.model.clone().or(global_model.model.clone()), - variables: model.variables.clone().or(global_model.variables.clone()), - } -} - -#[cfg(test)] -mod test { - - use std::str::FromStr; - - use sqlx::PgPool; - - use super::*; - - async fn add_cost_models(pool: &PgPool, models: Vec) { - for model in models { - sqlx::query!( - r#" - INSERT INTO "CostModelsHistory" (deployment, model) - VALUES ($1, $2); - "#, - model.deployment, - model.model, - ) - .execute(pool) - .await - .expect("Create test instance in db"); - } - } - - fn to_db_models(models: Vec) -> Vec { - models.into_iter().map(DbCostModel::from).collect() - } - - fn global_cost_model() -> DbCostModel { - DbCostModel { - deployment: Some("global".to_string()), - model: Some("default => 0.00001;".to_string()), - variables: None, - } - } - - fn test_data() -> Vec { - vec![ - CostModel { - deployment: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - .parse() - .unwrap(), - model: None, - variables: None, - }, - CostModel { - deployment: "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea" - .parse() - .unwrap(), - model: Some("default => 0.00025;".to_string()), 
- variables: None, - }, - CostModel { - deployment: "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc" - .parse() - .unwrap(), - model: Some("default => 0.00012;".to_string()), - variables: None, - }, - ] - } - - #[sqlx::test(migrations = "../migrations")] - async fn success_cost_models(pool: PgPool) { - let test_models = test_data(); - let test_deployments = test_models - .iter() - .map(|model| model.deployment) - .collect::>(); - - add_cost_models(&pool, to_db_models(test_models.clone())).await; - - // First test: query without deployment filter - let models = cost_models(&pool, &[]) - .await - .expect("cost models query without deployment filter"); - - // We expect as many models as we have in the test data - assert_eq!(models.len(), test_models.len()); - - // We expect models for all test deployments to be present and - // identical to the test data - for test_deployment in test_deployments.iter() { - let test_model = test_models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for test deployment in test data"); - - let model = models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for test deployment in query result"); - - assert_eq!(test_model.model, model.model); - } - - // Second test: query with a deployment filter - let sample_deployments = vec![ - test_models.first().unwrap().deployment, - test_models.get(1).unwrap().deployment, - ]; - let models = cost_models(&pool, &sample_deployments) - .await - .expect("cost models query with deployment filter"); - - // Expect two cost mdoels to be returned - assert_eq!(models.len(), sample_deployments.len()); - - // Expect both returned deployments to be identical to the test data - for test_deployment in sample_deployments.iter() { - let test_model = test_models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for test deployment in test data"); - - let model = models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for test deployment in query result"); - - assert_eq!(test_model.model, model.model); - } - } - - #[sqlx::test(migrations = "../migrations")] - async fn global_fallback_cost_models(pool: PgPool) { - let test_models = test_data(); - let test_deployments = test_models - .iter() - .map(|model| model.deployment) - .collect::>(); - let global_model = global_cost_model(); - - add_cost_models(&pool, to_db_models(test_models.clone())).await; - add_cost_models(&pool, vec![global_model.clone()]).await; - - // First test: fetch cost models without filtering by deployment - let models = cost_models(&pool, &[]) - .await - .expect("cost models query without deployments filter"); - - // Since we've defined 3 cost models and we did not provide a filter, we - // expect all of them to be returned except for the global cost model - assert_eq!(models.len(), test_models.len()); - - // Expect all test deployments to be present in the query result - for test_deployment in test_deployments.iter() { - let test_model = test_models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for deployment in test data"); - - let model = models - .iter() - .find(|model| &model.deployment == test_deployment) - .expect("finding cost model for deployment in query result"); - - if test_model.model.is_some() { - // If the test model has a model definition, we expect that to be returned - assert_eq!(model.model, test_model.model); 
- } else { - // If the test model has no model definition, we expect the global - // model definition to be returned - assert_eq!(model.model, global_model.model); - } - } - - // Second test: fetch cost models, filtering by the first two deployment IDs - let sample_deployments = vec![ - test_models.first().unwrap().deployment, - test_models.get(1).unwrap().deployment, - ]; - let models = dbg!(cost_models(&pool, &sample_deployments).await) - .expect("cost models query with deployments filter"); - - // We've filtered by two deployment IDs and are expecting two cost models to be returned - assert_eq!(models.len(), sample_deployments.len()); - - for test_deployment in sample_deployments { - let test_model = test_models - .iter() - .find(|model| model.deployment == test_deployment) - .expect("finding cost model for deployment in test data"); - - let model = models - .iter() - .find(|model| model.deployment == test_deployment) - .expect("finding cost model for deployment in query result"); - - if test_model.model.is_some() { - // If the test model has a model definition, we expect that to be returned - assert_eq!(model.model, test_model.model); - } else { - // If the test model has no model definition, we expect the global - // model definition to be returned - assert_eq!(model.model, global_model.model); - } - } - - // Third test: query for missing cost model - let missing_deployment = - DeploymentId::from_str("Qmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").unwrap(); - let models = cost_models(&pool, &[missing_deployment]) - .await - .expect("cost models query for missing deployment"); - - // The deployment may be missing in the database but we have a global model - // and expect that to be returned, with the missing deployment ID - let missing_model = models - .iter() - .find(|m| m.deployment == missing_deployment) - .expect("finding missing deployment"); - assert_eq!(missing_model.model, global_model.model); - } - - #[sqlx::test(migrations = "../migrations")] - async fn success_cost_model(pool: PgPool) { - add_cost_models(&pool, to_db_models(test_data())).await; - - let deployment_id_from_bytes = DeploymentId::from_str( - "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea", - ) - .unwrap(); - let deployment_id_from_hash = - DeploymentId::from_str("Qmb5Ysp5oCUXhLA8NmxmYKDAX2nCMnh7Vvb5uffb9n5vss").unwrap(); - - assert_eq!(deployment_id_from_bytes, deployment_id_from_hash); - - let model = cost_model(&pool, &deployment_id_from_bytes) - .await - .expect("cost model query") - .expect("cost model for deployment"); - - assert_eq!(model.deployment, deployment_id_from_hash); - assert_eq!(model.model, Some("default => 0.00025;".to_string())); - } - - #[sqlx::test(migrations = "../migrations")] - async fn global_fallback_cost_model(pool: PgPool) { - let test_models = test_data(); - let global_model = global_cost_model(); - - add_cost_models(&pool, to_db_models(test_models.clone())).await; - add_cost_models(&pool, vec![global_model.clone()]).await; - - // Test that the behavior is correct for existing deployments - for test_model in test_models { - let model = cost_model(&pool, &test_model.deployment) - .await - .expect("cost model query") - .expect("global cost model fallback"); - - assert_eq!(model.deployment, test_model.deployment); - - if test_model.model.is_some() { - // If the test model has a model definition, we expect that to be returned - assert_eq!(model.model, test_model.model); - } else { - // If the test model has no model definition, we expect the global - // model 
definition to be returned - assert_eq!(model.model, global_model.model); - } - } - - // Test that querying a non-existing deployment returns the default cost model - let missing_deployment = - DeploymentId::from_str("Qmnononononononononononononononononononononono").unwrap(); - let model = cost_model(&pool, &missing_deployment) - .await - .expect("cost model query") - .expect("global cost model fallback"); - assert_eq!(model.deployment, missing_deployment); - assert_eq!(model.model, global_model.model); - } -} diff --git a/service/src/routes/cost.rs b/service/src/routes/cost.rs index 08d4eb2cf..044a15f9f 100644 --- a/service/src/routes/cost.rs +++ b/service/src/routes/cost.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use async_graphql::{Context, EmptyMutation, EmptySubscription, Object, Schema, SimpleObject}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use axum::extract::State; +use indexer_common::cost_model::{self, CostModel}; use lazy_static::lazy_static; use prometheus::{ register_counter, register_counter_vec, register_histogram, register_histogram_vec, Counter, @@ -16,7 +17,6 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use thegraph_core::DeploymentId; -use crate::database::{self, CostModel}; use crate::service::SubgraphServiceState; lazy_static! { @@ -128,7 +128,7 @@ impl Query { deployment_ids: Vec, ) -> Result, anyhow::Error> { let pool = &ctx.data_unchecked::>().database; - let cost_models = database::cost_models(pool, &deployment_ids).await?; + let cost_models = cost_model::cost_models(pool, &deployment_ids).await?; Ok(cost_models.into_iter().map(|m| m.into()).collect()) } @@ -138,7 +138,7 @@ impl Query { deployment_id: DeploymentId, ) -> Result, anyhow::Error> { let pool = &ctx.data_unchecked::>().database; - database::cost_model(pool, &deployment_id) + cost_model::cost_model(pool, &deployment_id) .await .map(|model_opt| model_opt.map(GraphQlCostModel::from)) } From 1ab0866451278a0de0bb376af9d8ae6d53c53d46 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Thu, 24 Oct 2024 23:03:52 +0200 Subject: [PATCH 19/21] refactor: remove cache, use simple check Signed-off-by: Gustavo Inacio --- ...d18653f57a9d4b6f231c00b9eea60392a0a6f.json | 15 - ...edc4204bbbd32aac6f7da7e99fb501ca5cc14.json | 2 +- ...a1732bb3638a3a4f942a665cf1fd38eb70c2d.json | 2 +- ...651e43b6b9ffe9834df85a62707d3a2d051b4.json | 2 +- ...bc7c298006af22134c144048b0ac7183353c0.json | 12 + ...4651cfd94b4b82f26baf0755efa80e6045c0a.json | 2 +- ...893b4f7209c985c367215b2eed25adc78c462.json | 15 + common/src/cost_model.rs | 14 +- common/src/tap/checks/value_check.rs | 566 +++++++++++------- .../20230901142040_cost_models.down.sql | 8 +- migrations/20230901142040_cost_models.up.sql | 38 +- ...91258_add_cost_model_notification.down.sql | 4 + ...4191258_add_cost_model_notification.up.sql | 21 + 13 files changed, 403 insertions(+), 298 deletions(-) delete mode 100644 .sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json create mode 100644 .sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json create mode 100644 .sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json create mode 100644 migrations/20241024191258_add_cost_model_notification.down.sql create mode 100644 migrations/20241024191258_add_cost_model_notification.up.sql diff --git a/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json b/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json deleted file mode 100644 
index b617582e0..000000000 --- a/.sqlx/query-4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO \"CostModelsHistory\" (deployment, model)\n VALUES ($1, $2);\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Varchar", - "Text" - ] - }, - "nullable": [] - }, - "hash": "4129630725eb7f563f2b0532366d18653f57a9d4b6f231c00b9eea60392a0a6f" -} diff --git a/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json b/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json index 91e33b3ed..2b0cde21e 100644 --- a/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json +++ b/.sqlx/query-842bde7fba1c7652b7cfc2dc568edc4204bbbd32aac6f7da7e99fb501ca5cc14.json @@ -23,7 +23,7 @@ "Left": [] }, "nullable": [ - true, + false, true, true ] diff --git a/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json b/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json index d25a9de5d..7dbd08638 100644 --- a/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json +++ b/.sqlx/query-b54b1069daf03a377a0e7c09c9aa1732bb3638a3a4f942a665cf1fd38eb70c2d.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - true, + false, true, true ] diff --git a/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json b/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json index c9e8b2f4e..6b2da69c8 100644 --- a/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json +++ b/.sqlx/query-d93dd26d7221c5e1ae15a919a2a651e43b6b9ffe9834df85a62707d3a2d051b4.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - true, + false, true, true ] diff --git a/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json b/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json new file mode 100644 index 000000000..65922247c --- /dev/null +++ b/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM \"CostModels\"", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0" +} diff --git a/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json b/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json index 08c287f3f..619967273 100644 --- a/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json +++ b/.sqlx/query-e14503b633fc673b65448e70c204651cfd94b4b82f26baf0755efa80e6045c0a.json @@ -25,7 +25,7 @@ ] }, "nullable": [ - true, + false, true, true ] diff --git a/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json b/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json new file mode 100644 index 000000000..0c8cfd917 --- /dev/null +++ b/.sqlx/query-ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO \"CostModels\" (deployment, model)\n VALUES ($1, $2);\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Varchar", + "Text" + ] + }, + "nullable": [] + }, + "hash": "ef6affb9039ad19a69f4a5116d7893b4f7209c985c367215b2eed25adc78c462" +} diff --git 
a/common/src/cost_model.rs b/common/src/cost_model.rs
index 2924a1c0f..a2d50570b 100644
--- a/common/src/cost_model.rs
+++ b/common/src/cost_model.rs
@@ -12,7 +12,7 @@ use thegraph_core::{DeploymentId, ParseDeploymentIdError};
 /// These can have "global" as the deployment ID.
 #[derive(Debug, Clone)]
 pub(crate) struct DbCostModel {
-    pub deployment: Option<String>,
+    pub deployment: String,
     pub model: Option<String>,
     pub variables: Option<serde_json::Value>,
 }
@@ -32,7 +32,7 @@ impl TryFrom<DbCostModel> for CostModel {
 
     fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> {
         Ok(Self {
-            deployment: DeploymentId::from_str(&db_model.deployment.unwrap())?,
+            deployment: DeploymentId::from_str(&db_model.deployment)?,
             model: db_model.model,
             variables: db_model.variables,
         })
@@ -43,7 +43,7 @@ impl From<CostModel> for DbCostModel {
     fn from(model: CostModel) -> Self {
         let deployment = model.deployment;
         DbCostModel {
-            deployment: Some(format!("{deployment:#x}")),
+            deployment: format!("{deployment:#x}"),
             model: model.model,
             variables: model.variables,
         }
@@ -164,7 +164,7 @@ pub async fn cost_model(
 }
 
 /// Query global cost model
-async fn global_cost_model(pool: &PgPool) -> Result<Option<DbCostModel>, anyhow::Error> {
+pub(crate) async fn global_cost_model(pool: &PgPool) -> Result<Option<DbCostModel>, anyhow::Error> {
     sqlx::query_as!(
         DbCostModel,
         r#"
@@ -200,7 +200,7 @@ pub(crate) mod test {
         for model in models {
             sqlx::query!(
                 r#"
-                INSERT INTO "CostModelsHistory" (deployment, model)
+                INSERT INTO "CostModels" (deployment, model)
                 VALUES ($1, $2);
                 "#,
                 model.deployment,
@@ -216,9 +216,9 @@
         models.into_iter().map(DbCostModel::from).collect()
     }
 
-    fn global_cost_model() -> DbCostModel {
+    pub fn global_cost_model() -> DbCostModel {
         DbCostModel {
-            deployment: Some("global".to_string()),
+            deployment: "global".to_string(),
             model: Some("default => 0.00001;".to_string()),
             variables: None,
         }
diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index ec9abe714..e42d29b93 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -1,27 +1,20 @@
 // Copyright 2023-, GraphOps and Semiotic Labs.
// SPDX-License-Identifier: Apache-2.0 +use ::cost_model::CostModel; use anyhow::anyhow; use bigdecimal::ToPrimitive; -use cost_model::CostModel; use sqlx::{ postgres::{PgListener, PgNotification}, PgPool, }; use std::{ - cmp::min, - collections::{hash_map::Entry, HashMap, VecDeque}, + collections::HashMap, str::FromStr, - sync::{Arc, Mutex, RwLock}, - time::Duration, + sync::{Arc, RwLock}, }; -use thegraph_core::{DeploymentId, ParseDeploymentIdError}; -use tokio::{ - sync::mpsc::{channel, Sender}, - task::JoinHandle, - time::sleep, -}; -use tracing::{debug, error}; +use thegraph_core::DeploymentId; +use tracing::error; use tap_core::receipt::{ checks::{Check, CheckError, CheckResult}, @@ -29,44 +22,45 @@ use tap_core::receipt::{ Context, ReceiptWithState, }; +use crate::cost_model; + // we only accept receipts with minimal 1 wei grt const MINIMAL_VALUE: u128 = 1; -type CostModelMap = Arc>>>; +pub struct AgoraQuery { + pub deployment_id: DeploymentId, + pub query: String, + pub variables: String, +} + +type CostModelMap = Arc>>; +type GlobalModel = Arc>>; pub struct MinimumValue { - cost_model_cache: CostModelMap, - global_model_cache: Arc>, + cost_model_map: CostModelMap, + global_model: GlobalModel, watcher_cancel_token: tokio_util::sync::CancellationToken, } struct CostModelWatcher { pgpool: PgPool, - tx: Sender, - - cost_model_cache: CostModelMap, - global_model_cache: Arc>, - handles: Arc>>>>, + cost_models: CostModelMap, + global_model: GlobalModel, } impl CostModelWatcher { async fn cost_models_watcher( pgpool: PgPool, mut pglistener: PgListener, - cost_model_cache: CostModelMap, - global_model_cache: Arc>, + cost_models: CostModelMap, + global_model: GlobalModel, cancel_token: tokio_util::sync::CancellationToken, ) { - let handles: Arc>>>> = - Default::default(); - let (tx, mut rx) = channel::(64); let cost_model_watcher = CostModelWatcher { pgpool, - tx, - handles, - global_model_cache, - cost_model_cache, + global_model, + cost_models, }; loop { @@ -74,113 +68,80 @@ impl CostModelWatcher { _ = cancel_token.cancelled() => { break; } - Some(deployment_id) = rx.recv() => { - cost_model_watcher.cancel_cache_expire(deployment_id).await; - } - pg_notification = pglistener.recv() => { - let pg_notification = pg_notification.expect( - "should be able to receive Postgres Notify events on the channel \ - 'cost_models_update_notify'", - ); + Ok(pg_notification) = pglistener.recv() => { cost_model_watcher.new_notification( pg_notification, ).await; - } } } } async fn new_notification(&self, pg_notification: PgNotification) { - let cost_model_notification: CostModelNotification = - serde_json::from_str(pg_notification.payload()).expect( - "should be able to deserialize the Postgres Notify event payload as a \ - CostModelNotification", - ); - - let deployment_id = match cost_model_notification.deployment.as_str() { - "global" => { - debug!("Received an update for 'global' cost model"); - return; - } - deployment_id => DeploymentId::from_str(deployment_id).unwrap(), - }; - - match cost_model_notification.tg_op.as_str() { - "INSERT" => { - let cost_model_source: CostModelSource = - cost_model_notification.try_into().unwrap(); - { - let mut cost_model_write = self.cost_model_cache.write().unwrap(); - let cache = cost_model_write.entry(deployment_id).or_default(); - let _ = cache.get_mut().unwrap().insert_model(cost_model_source); - } - let _tx = self.tx.clone(); - - // expire after 60 seconds - self.handles - .lock() - .unwrap() - .entry(deployment_id) - .or_default() - 
.push_back(tokio::spawn(async move { - // 1 minute after, we expire the older cache - sleep(Duration::from_secs(60)).await; - let _ = _tx.send(deployment_id).await; - })); + let payload = pg_notification.payload(); + let cost_model_notification: Result = + serde_json::from_str(payload); + + match cost_model_notification { + Ok(CostModelNotification::Insert { + deployment, + model, + variables, + }) => { + let model = compile_cost_model(model, variables).unwrap(); + + match deployment.as_str() { + "global" => { + *self.global_model.write().unwrap() = Some(model); + } + deployment_id => match DeploymentId::from_str(deployment_id) { + Ok(deployment_id) => { + let mut cost_model_write = self.cost_models.write().unwrap(); + cost_model_write.insert(deployment_id, model); + } + Err(_) => { + error!( + "Received insert request for an invalid deployment_id: {}", + deployment_id + ) + } + }, + }; } - "DELETE" => { - if let Entry::Occupied(mut entry) = - self.cost_model_cache.write().unwrap().entry(deployment_id) - { - let should_remove = { - let mut cost_model = entry.get_mut().write().unwrap(); - cost_model.expire(); - cost_model.is_empty() - }; - if should_remove { - entry.remove(); + Ok(CostModelNotification::Delete { deployment }) => { + match deployment.as_str() { + "global" => { + *self.global_model.write().unwrap() = None; } - } + deployment_id => match DeploymentId::from_str(deployment_id) { + Ok(deployment_id) => { + self.cost_models.write().unwrap().remove(&deployment_id); + } + Err(_) => { + error!( + "Received delete request for an invalid deployment_id: {}", + deployment_id + ) + } + }, + }; } // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost // model cache. - _ => { + Err(_) => { error!( "Received an unexpected cost model table notification: {}. 
Reloading entire \ cost model.", - cost_model_notification.tg_op + payload ); - { - // clear all pending expire - let mut handles = self.handles.lock().unwrap(); - for maps in handles.values() { - for handle in maps { - handle.abort(); - } - } - handles.clear(); - } - - MinimumValue::value_check_reload(&self.pgpool, self.cost_model_cache.clone()) - .await - .expect("should be able to reload cost models") - } - } - } - - async fn cancel_cache_expire(&self, deployment_id: DeploymentId) { - let mut cost_model_write = self.cost_model_cache.write().unwrap(); - if let Some(cache) = cost_model_write.get_mut(&deployment_id) { - cache.get_mut().unwrap().expire(); - } - - if let Entry::Occupied(mut entry) = self.handles.lock().unwrap().entry(deployment_id) { - let vec = entry.get_mut(); - vec.pop_front(); - if vec.is_empty() { - entry.remove(); + MinimumValue::value_check_reload( + &self.pgpool, + self.cost_models.clone(), + self.global_model.clone(), + ) + .await + .expect("should be able to reload cost models") } } } @@ -196,59 +157,60 @@ impl Drop for MinimumValue { impl MinimumValue { pub async fn new(pgpool: PgPool) -> Self { - let cost_model_cache: CostModelMap = Default::default(); - Self::value_check_reload(&pgpool, cost_model_cache.clone()) + let cost_model_map: CostModelMap = Default::default(); + let global_model: GlobalModel = Default::default(); + Self::value_check_reload(&pgpool, cost_model_map.clone(), global_model.clone()) .await .expect("should be able to reload cost models"); let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); - pglistener.listen("cost_models_update_notify").await.expect( - "should be able to subscribe to Postgres Notify events on the channel \ - 'cost_models_update_notify'", - ); - - let global_model_cache: Arc> = Default::default(); + pglistener + .listen("cost_models_update_notification") + .await + .expect( + "should be able to subscribe to Postgres Notify events on the channel \ + 'cost_models_update_notification'", + ); let watcher_cancel_token = tokio_util::sync::CancellationToken::new(); tokio::spawn(CostModelWatcher::cost_models_watcher( pgpool.clone(), pglistener, - cost_model_cache.clone(), - global_model_cache.clone(), + cost_model_map.clone(), + global_model.clone(), watcher_cancel_token.clone(), )); Self { - global_model_cache, - cost_model_cache, + global_model, + cost_model_map, watcher_cancel_token, } } fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result { - // get agora model for the allocation_id - let cache = self.cost_model_cache.read().unwrap(); - let models = cache.get(&agora_query.deployment_id); - - // TODO check global cost model - - let expected_value = models - .map(|cache| { - let cache = cache.read().unwrap(); - cache.cost(agora_query) - }) - .unwrap_or(MINIMAL_VALUE); + // get agora model for the deployment_id + let model = self.cost_model_map.read().unwrap(); + let subgraph_model = model.get(&agora_query.deployment_id); + let global_model = self.global_model.read().unwrap(); + + let expected_value = match (subgraph_model, global_model.as_ref()) { + (Some(model), _) | (_, Some(model)) => model + .cost(&agora_query.query, &agora_query.variables) + .map(|fee| fee.to_u128()) + .ok() + .flatten(), + _ => None, + }; - Ok(expected_value) + Ok(expected_value.unwrap_or(MINIMAL_VALUE)) } async fn value_check_reload( pgpool: &PgPool, - cost_model_cache: CostModelMap, + cost_model_map: CostModelMap, + global_model: GlobalModel, ) -> anyhow::Result<()> { - // TODO make sure to load last cost model - 
// plus all models that were created 60 secoonds from now - // let models = sqlx::query!( r#" SELECT deployment, model, variables @@ -261,19 +223,29 @@ impl MinimumValue { .await?; let models = models .into_iter() - .map(|record| { - let deployment_id = DeploymentId::from_str(&record.deployment.unwrap())?; - let mut model = CostModelCache::default(); - let _ = model.insert_model(CostModelSource { - model: record.model.unwrap_or_default(), - variables: record.variables.unwrap_or_default(), - }); - - Ok::<_, ParseDeploymentIdError>((deployment_id, RwLock::new(model))) + .flat_map(|record| { + let deployment_id = DeploymentId::from_str(&record.deployment).ok()?; + let model = compile_cost_model( + record.model?, + record.variables.map(|v| v.to_string()).unwrap_or_default(), + ) + .ok()?; + Some((deployment_id, model)) }) - .collect::, _>>()?; - - *(cost_model_cache.write().unwrap()) = models; + .collect::>(); + + *cost_model_map.write().unwrap() = models; + + *global_model.write().unwrap() = + cost_model::global_cost_model(pgpool) + .await? + .and_then(|model| { + compile_cost_model( + model.model.unwrap_or_default(), + model.variables.map(|v| v.to_string()).unwrap_or_default(), + ) + .ok() + }); Ok(()) } @@ -321,109 +293,243 @@ fn compile_cost_model(model: String, variables: String) -> anyhow::Result, - variables: Option, -} + use sqlx::PgPool; + use tap_core::receipt::{checks::Check, Context, ReceiptWithState}; + use tokio::time::sleep; -impl From for CostModelSource { - fn from(value: CostModelNotification) -> Self { - CostModelSource { - model: value.model.unwrap_or_default(), - variables: value.variables.unwrap_or_default(), - } + use crate::{ + cost_model::test::{add_cost_models, global_cost_model, to_db_models}, + tap::AgoraQuery, + test_vectors::create_signed_receipt, + }; + + use super::MinimumValue; + + #[sqlx::test(migrations = "../migrations")] + async fn initialize_check(pgpool: PgPool) { + let check = MinimumValue::new(pgpool).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); } -} -#[derive(Default)] -struct CostModelCache { - models: VecDeque, -} + #[sqlx::test(migrations = "../migrations")] + async fn should_initialize_check_with_models(pgpool: PgPool) { + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); -impl CostModelCache { - fn insert_model(&mut self, source: CostModelSource) -> anyhow::Result<()> { - let model = compile_cost_model(source.model, source.variables.to_string())?; - self.models.push_back(model); - Ok(()) + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 2); + + // no global model + assert!(check.global_model.read().unwrap().is_none()); } - fn expire(&mut self) { - self.models.pop_front(); + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_model_insert(pgpool: PgPool) { + let check = MinimumValue::new(pgpool.clone()).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); + + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + sleep(Duration::from_millis(200)).await; + + assert_eq!( + check.cost_model_map.read().unwrap().len(), + test_models.len() + ); } - fn is_empty(&self) -> bool { - self.models.is_empty() + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_model_remove(pgpool: 
PgPool) { + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 2); + + // remove + sqlx::query!(r#"DELETE FROM "CostModels""#) + .execute(&pgpool) + .await + .unwrap(); + + sleep(Duration::from_millis(200)).await; + + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); } - fn cost(&self, query: &AgoraQuery) -> u128 { - self.models - .iter() - .fold(None, |acc, model| { - let value = model - .cost(&query.query, &query.variables) - .ok() - .map(|fee| fee.to_u128().unwrap_or_default()) - .unwrap_or_default(); - if let Some(acc) = acc { - // return the minimum value of the cache list - Some(min(acc, value)) - } else { - Some(value) - } - }) - .unwrap_or(MINIMAL_VALUE) + #[sqlx::test(migrations = "../migrations")] + async fn should_start_global_model(pgpool: PgPool) { + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert!(check.global_model.read().unwrap().is_some()); } -} -#[cfg(test)] -mod tests { - use sqlx::PgPool; + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_global_model(pgpool: PgPool) { + let check = MinimumValue::new(pgpool.clone()).await; - use crate::cost_model::test::{add_cost_models, to_db_models}; + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + sleep(Duration::from_millis(10)).await; - use super::MinimumValue; + assert!(check.global_model.read().unwrap().is_some()); + } #[sqlx::test(migrations = "../migrations")] - async fn initialize_check(pgpool: PgPool) { - let check = MinimumValue::new(pgpool).await; - assert_eq!(check.cost_model_cache.read().unwrap().len(), 0); + async fn should_remove_global_model(pgpool: PgPool) { + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert!(check.global_model.read().unwrap().is_some()); + + sqlx::query!(r#"DELETE FROM "CostModels""#) + .execute(&pgpool) + .await + .unwrap(); + + sleep(Duration::from_millis(10)).await; + + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); } + const ALLOCATION_ID: Address = address!("deadbeefcafebabedeadbeefcafebabedeadbeef"); + #[sqlx::test(migrations = "../migrations")] - async fn should_initialize_check_with_caches(pgpool: PgPool) { - // insert 2 cost models for different deployment_id + async fn should_check_minimal_value(pgpool: PgPool) { + // insert cost models for different deployment_id let test_models = crate::cost_model::test::test_data(); add_cost_models(&pgpool, to_db_models(test_models.clone())).await; let check = MinimumValue::new(pgpool).await; - assert_eq!( - check.cost_model_cache.read().unwrap().len(), - test_models.len() + + let deployment_id = test_models[0].deployment; + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id, + query: "query { a(skip: 10), b(bob: 5) }".into(), + variables: "".into(), + }); + + let signed_receipt = create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, 0).await; + let receipt = ReceiptWithState::new(signed_receipt); + + assert!( + check.check(&ctx, &receipt).await.is_err(), + "Should deny if value is 0 for any query" + ); + + let signed_receipt = create_signed_receipt(ALLOCATION_ID, 
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, 1).await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        assert!(
+            check.check(&ctx, &receipt).await.is_ok(),
+            "Should accept if value is more than 0 for any query"
         );
+
+        let deployment_id = test_models[1].deployment;
+        let mut ctx = Context::new();
+        ctx.insert(AgoraQuery {
+            deployment_id,
+            query: "query { a(skip: 10), b(bob: 5) }".into(),
+            variables: "".into(),
+        });
+        let minimal_value = 500000000000000;
+
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_value - 1).await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        assert!(
+            check.check(&ctx, &receipt).await.is_err(),
+            "Should require minimal value"
+        );
+
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_value).await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        check
+            .check(&ctx, &receipt)
+            .await
+            .expect("should accept equals minimal");
+
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_value + 1).await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        check
+            .check(&ctx, &receipt)
+            .await
+            .expect("should accept more than minimal");
     }
 
     #[sqlx::test(migrations = "../migrations")]
-    async fn should_add_model_to_cache_on_insert(_pg_pool: PgPool) {}
+    async fn should_check_using_global(pgpool: PgPool) {
+        // insert cost models for different deployment_id
+        let test_models = crate::cost_model::test::test_data();
+        let global_model = global_cost_model();
 
-    #[sqlx::test(migrations = "../migrations")]
-    async fn should_expire_old_model(_pg_pool: PgPool) {}
+        add_cost_models(&pgpool, vec![global_model.clone()]).await;
+        add_cost_models(&pgpool, to_db_models(test_models.clone())).await;
 
-    #[sqlx::test(migrations = "../migrations")]
-    async fn should_verify_global_model(_pg_pool: PgPool) {}
+        let check = MinimumValue::new(pgpool).await;
+
+        let deployment_id = test_models[0].deployment;
+        let mut ctx = Context::new();
+        ctx.insert(AgoraQuery {
+            deployment_id,
+            query: "query { a(skip: 10), b(bob: 5) }".into(),
+            variables: "".into(),
+        });
+
+        let minimal_global_value = 20000000000000;
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value - 1)
+                .await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+
+        assert!(
+            check.check(&ctx, &receipt).await.is_err(),
+            "Should deny less than global"
+        );
+
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value).await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        check
+            .check(&ctx, &receipt)
+            .await
+            .expect("should accept equals global");
+
+        let signed_receipt =
+            create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value + 1)
+                .await;
+        let receipt = ReceiptWithState::new(signed_receipt);
+        check
+            .check(&ctx, &receipt)
+            .await
+            .expect("should accept more than global");
+    }
 }
diff --git a/migrations/20230901142040_cost_models.down.sql b/migrations/20230901142040_cost_models.down.sql
index 744b443e9..885eeefc6 100644
--- a/migrations/20230901142040_cost_models.down.sql
+++ b/migrations/20230901142040_cost_models.down.sql
@@ -1,8 +1,2 @@
 -- Add down migration script here
-DROP TRIGGER IF EXISTS cost_models_update ON "CostModelsHistory" CASCADE;
-
-DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE;
-
-DROP VIEW "CostModels";
-
-DROP TABLE "CostModelsHistory";
+DROP TABLE "CostModels";
diff --git a/migrations/20230901142040_cost_models.up.sql b/migrations/20230901142040_cost_models.up.sql
index 1e3dc7ad7..96b10f0bd 100644
--- a/migrations/20230901142040_cost_models.up.sql
+++ b/migrations/20230901142040_cost_models.up.sql
@@ -1,40 +1,8 @@
-CREATE TABLE IF NOT EXISTS "CostModelsHistory"
+CREATE TABLE IF NOT EXISTS "CostModels"
 (
-    id SERIAL PRIMARY KEY,
+    id INT,
     deployment VARCHAR NOT NULL,
     model TEXT,
     variables JSONB,
-    "createdAt" TIMESTAMP WITH TIME ZONE,
-    "updatedAt" TIMESTAMP WITH TIME ZONE
+    PRIMARY KEY (deployment)
 );
-
-CREATE VIEW "CostModels" AS
-SELECT DISTINCT ON (deployment, model, variables)
-    deployment,
-    model,
-    variables,
-    "createdAt",
-    "updatedAt"
-FROM "CostModelsHistory"
-ORDER BY deployment, model, variables, id DESC;
-
-CREATE FUNCTION cost_models_update_notify()
-RETURNS trigger AS
-$$
-BEGIN
-    IF TG_OP = 'DELETE' THEN
-        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment));
-        RETURN OLD;
-    ELSIF TG_OP = 'INSERT' THEN
-        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
-        RETURN NEW;
-    ELSE
-        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "%s", "deployment": "%s", "model": "%s"}', NEW.deployment, NEW.model));
-        RETURN NEW;
-    END IF;
-END;
-$$ LANGUAGE 'plpgsql';
-
-CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE
-    ON "CostModelsHistory"
-    FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();
diff --git a/migrations/20241024191258_add_cost_model_notification.down.sql b/migrations/20241024191258_add_cost_model_notification.down.sql
new file mode 100644
index 000000000..16d32dfcc
--- /dev/null
+++ b/migrations/20241024191258_add_cost_model_notification.down.sql
@@ -0,0 +1,4 @@
+-- Add down migration script here
+DROP TRIGGER IF EXISTS cost_models_update ON "CostModels" CASCADE;
+
+DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE;
diff --git a/migrations/20241024191258_add_cost_model_notification.up.sql b/migrations/20241024191258_add_cost_model_notification.up.sql
new file mode 100644
index 000000000..c2f3fb37c
--- /dev/null
+++ b/migrations/20241024191258_add_cost_model_notification.up.sql
@@ -0,0 +1,21 @@
+-- Add up migration script here
+CREATE FUNCTION cost_models_update_notify()
+RETURNS trigger AS
+$$
+BEGIN
+    IF TG_OP = 'DELETE' THEN
+        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment));
+        RETURN OLD;
+    ELSIF TG_OP = 'INSERT' THEN
+        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s", "variables": "%s"}', NEW.deployment, NEW.model, NEW.variables));
+        RETURN NEW;
+    ELSE
+        PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "%s", "deployment": "%s", "model": "%s", "variables": "%s"}', TG_OP, NEW.deployment, NEW.model, NEW.variables));
+        RETURN NEW;
+    END IF;
+END;
+$$ LANGUAGE 'plpgsql';
+
+CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE
+    ON "CostModels"
+    FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();

From 026815bd22507cc24446cc281f0f57a0de9dd369 Mon Sep 17 00:00:00 2001
From: Gustavo Inacio
Date: Tue, 29 Oct 2024 12:51:47 -0600
Subject: [PATCH 20/21] docs: add documentation to public structs

Signed-off-by: Gustavo Inacio

---
 common/src/tap/checks/value_check.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index e42d29b93..d27d49bd2 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -27,6 +27,10 @@ use crate::cost_model;
 // we only accept receipts with minimal 1 wei grt
 const MINIMAL_VALUE: u128 = 1;
 
+/// Represents a query that can be checked against an agora model
+///
+/// It contains the deployment_id used to select which agora model to
+/// evaluate, plus the query and variables needed for the evaluation
 pub struct AgoraQuery {
     pub deployment_id: DeploymentId,
     pub query: String,
@@ -36,6 +40,10 @@ pub struct AgoraQuery {
 type CostModelMap = Arc<RwLock<HashMap<DeploymentId, CostModel>>>;
 type GlobalModel = Arc<RwLock<Option<CostModel>>>;
 
+/// Represents the minimum value check for a receipt
+///
+/// It keeps all the information it needs in memory to
+/// make the check as fast as possible
 pub struct MinimumValue {
     cost_model_map: CostModelMap,
     global_model: GlobalModel,

From 8092cab0f8ec0607204f0b5464b1db91a08b27f7 Mon Sep 17 00:00:00 2001
From: Gustavo Inacio
Date: Tue, 29 Oct 2024 12:52:30 -0600
Subject: [PATCH 21/21] refactor: split new_notification

Signed-off-by: Gustavo Inacio

---
 common/src/tap/checks/value_check.rs | 116 ++++++++++++++-------------
 1 file changed, 61 insertions(+), 55 deletions(-)

diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs
index d27d49bd2..daf9237c3 100644
--- a/common/src/tap/checks/value_check.rs
+++ b/common/src/tap/checks/value_check.rs
@@ -95,63 +95,69 @@ impl CostModelWatcher {
                     deployment,
                     model,
                     variables,
-                }) => {
-                    let model = compile_cost_model(model, variables).unwrap();
-
-                    match deployment.as_str() {
-                        "global" => {
-                            *self.global_model.write().unwrap() = Some(model);
-                        }
-                        deployment_id => match DeploymentId::from_str(deployment_id) {
-                            Ok(deployment_id) => {
-                                let mut cost_model_write = self.cost_models.write().unwrap();
-                                cost_model_write.insert(deployment_id, model);
-                            }
-                            Err(_) => {
-                                error!(
-                                    "Received insert request for an invalid deployment_id: {}",
-                                    deployment_id
-                                )
-                            }
-                        },
-                    };
-                }
-                Ok(CostModelNotification::Delete { deployment }) => {
-                    match deployment.as_str() {
-                        "global" => {
-                            *self.global_model.write().unwrap() = None;
-                        }
-                        deployment_id => match DeploymentId::from_str(deployment_id) {
-                            Ok(deployment_id) => {
-                                self.cost_models.write().unwrap().remove(&deployment_id);
-                            }
-                            Err(_) => {
-                                error!(
-                                    "Received delete request for an invalid deployment_id: {}",
-                                    deployment_id
-                                )
-                            }
-                        },
-                    };
-                }
+                }) => self.handle_insert(deployment, model, variables),
+                Ok(CostModelNotification::Delete { deployment }) => self.handle_delete(deployment),
                 // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost
                 // model cache.
-                Err(_) => {
-                    error!(
-                        "Received an unexpected cost model table notification: {}. Reloading entire \
-                         cost model.",
-                        payload
-                    );
-
-                    MinimumValue::value_check_reload(
-                        &self.pgpool,
-                        self.cost_models.clone(),
-                        self.global_model.clone(),
-                    )
-                    .await
-                    .expect("should be able to reload cost models")
-                }
+                Err(_) => self.handle_unexpected_notification(payload).await,
             }
         }
+
+    fn handle_insert(&self, deployment: String, model: String, variables: String) {
+        let model = compile_cost_model(model, variables).unwrap();
+
+        match deployment.as_str() {
+            "global" => {
+                *self.global_model.write().unwrap() = Some(model);
+            }
+            deployment_id => match DeploymentId::from_str(deployment_id) {
+                Ok(deployment_id) => {
+                    let mut cost_model_write = self.cost_models.write().unwrap();
+                    cost_model_write.insert(deployment_id, model);
+                }
+                Err(_) => {
+                    error!(
+                        "Received insert request for an invalid deployment_id: {}",
+                        deployment_id
+                    )
+                }
+            },
+        };
+    }
+
+    fn handle_delete(&self, deployment: String) {
+        match deployment.as_str() {
+            "global" => {
+                *self.global_model.write().unwrap() = None;
+            }
+            deployment_id => match DeploymentId::from_str(deployment_id) {
+                Ok(deployment_id) => {
+                    self.cost_models.write().unwrap().remove(&deployment_id);
+                }
+                Err(_) => {
+                    error!(
+                        "Received delete request for an invalid deployment_id: {}",
+                        deployment_id
+                    )
+                }
+            },
+        };
+    }
+
+    async fn handle_unexpected_notification(&self, payload: &str) {
+        error!(
+            "Received an unexpected cost model table notification: {}. Reloading entire \
+             cost model.",
+            payload
+        );
+
+        MinimumValue::value_check_reload(
+            &self.pgpool,
+            self.cost_models.clone(),
+            self.global_model.clone(),
+        )
+        .await
+        .expect("should be able to reload cost models")
+    }
 }
@@ -196,7 +202,7 @@ impl MinimumValue {
         }
     }
 
-    fn get_expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
+    fn expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> {
         // get agora model for the deployment_id
         let model = self.cost_model_map.read().unwrap();
         let subgraph_model = model.get(&agora_query.deployment_id);
@@ -267,7 +273,7 @@ impl Check for MinimumValue {
             .ok_or(CheckError::Failed(anyhow!("Could not find agora query")))?;
 
         let expected_value = self
-            .get_expected_value(agora_query)
+            .expected_value(agora_query)
             .map_err(CheckError::Failed)?;
 
         // get value
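
Note on the notification payloads: the trigger installed by the migration above serializes each change as JSON on the `cost_models_update_notification` channel. Below is a minimal, self-contained sketch of how such a payload can be decoded; it mirrors the shape of the `CostModelNotification` enum used by the watcher, but this serde layout (and the `main` driver) is an illustration written against serde/serde_json, not necessarily the crate's actual definition:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    #[serde(tag = "tg_op")]
    enum CostModelNotification {
        #[serde(rename = "INSERT")]
        Insert {
            deployment: String,
            model: String,
            variables: String,
        },
        #[serde(rename = "DELETE")]
        Delete { deployment: String },
    }

    fn main() {
        // Shape produced by cost_models_update_notify() for an INSERT
        let payload =
            r#"{"tg_op": "INSERT", "deployment": "global", "model": "default => 1;", "variables": "{}"}"#;
        match serde_json::from_str::<CostModelNotification>(payload) {
            Ok(CostModelNotification::Insert { deployment, .. }) => {
                // compile the model and store it (or replace the global one)
                println!("upsert cost model for {deployment}");
            }
            Ok(CostModelNotification::Delete { deployment }) => {
                println!("drop cost model for {deployment}");
            }
            // UPDATE/TRUNCATE payloads match neither variant, which is why
            // the watcher falls back to reloading every cost model
            Err(e) => eprintln!("unexpected notification ({e}), reloading all models"),
        }
    }

One caveat worth keeping in mind: format('%s') in the trigger does not JSON-escape `model` or `variables`, so a model containing quotes or backslashes can produce a payload that fails to parse and forces the reload path.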