diff --git a/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json b/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json deleted file mode 100644 index 033c25bb3..000000000 --- a/.sqlx/query-58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n CREATE TABLE \"CostModels\"(\n id INT,\n deployment VARCHAR NOT NULL,\n model TEXT,\n variables JSONB,\n PRIMARY KEY( deployment )\n );\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [] - }, - "nullable": [] - }, - "hash": "58c4216b3ce62314fa727e0f3c504bd64ca10a64a95e4eb8bd7bb822c1625ed6" -} diff --git a/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json b/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json new file mode 100644 index 000000000..65922247c --- /dev/null +++ b/.sqlx/query-dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM \"CostModels\"", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "dcc710ebd4fa7a86f95a45ebf13bc7c298006af22134c144048b0ac7183353c0" +} diff --git a/Cargo.lock b/Cargo.lock index 82d5cb970..9472072cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,7 +353,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26154390b1d205a4a7ac7352aa2eb4f81f391399d4e2f546fb81a2f8bb383f62" dependencies = [ "alloy-rlp-derive", - "arrayvec", + "arrayvec 0.7.6", "bytes", ] @@ -701,6 +701,12 @@ version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +[[package]] +name = "anymap3" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9316150bf847270bd54e404c3fb1c5d3c6af118b3f0c3b42097da5ceafbe7c27" + [[package]] name = "ark-ff" version = "0.3.0" @@ -712,7 +718,7 @@ dependencies = [ "ark-serialize 0.3.0", "ark-std 0.3.0", "derivative", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "paste", "rustc_version 0.3.3", @@ -732,7 +738,7 @@ dependencies = [ "derivative", "digest 0.10.7", "itertools 0.10.5", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "paste", "rustc_version 0.4.1", @@ -765,7 +771,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "quote", "syn 1.0.109", @@ -777,7 +783,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "proc-macro2", "quote", @@ -802,7 +808,7 @@ checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" dependencies = [ "ark-std 0.4.0", "digest 0.10.7", - "num-bigint", + "num-bigint 0.4.6", ] [[package]] @@ -825,6 +831,12 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "arrayvec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + [[package]] name = "arrayvec" version = "0.7.6" @@ -1263,7 +1275,7 @@ checksum = 
"51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" dependencies = [ "autocfg", "libm", - "num-bigint", + "num-bigint 0.4.6", "num-integer", "num-traits", "serde", @@ -1482,7 +1494,7 @@ dependencies = [ "bincode", "build-info-common", "chrono", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "proc-macro-error2", "proc-macro2", @@ -1818,6 +1830,22 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cost-model" +version = "0.1.0" +source = "git+https://github.com/graphprotocol/agora?rev=3ed34ca#3ed34cae6dadded9c8c6ca34356309c98f0f0578" +dependencies = [ + "firestorm", + "fraction", + "graphql 0.2.0", + "itertools 0.12.1", + "lazy_static", + "nom 5.1.3", + "num-bigint 0.2.6", + "num-traits", + "serde_json", +] + [[package]] name = "cpufeatures" version = "0.2.14" @@ -2243,7 +2271,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "auto_impl", "bytes", ] @@ -2347,6 +2375,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "fraction" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a27f0e7512f6915c9bc38594725e33d7673da9308fea0abf4cc258c281cdbb2a" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -2566,6 +2604,16 @@ dependencies = [ "spinning_top", ] +[[package]] +name = "graphql" +version = "0.2.0" +source = "git+https://github.com/edgeandnode/toolshed?tag=graphql-v0.2.0#2df5ee975656027d4a7fbf591baf6a29dfbe0ee6" +dependencies = [ + "firestorm", + "graphql-parser", + "serde", +] + [[package]] name = "graphql" version = "0.3.0" @@ -3095,6 +3143,7 @@ dependencies = [ "bigdecimal", "bip39", "build-info", + "cost-model", "env_logger", "eventuals", "graphql_client", @@ -3106,7 +3155,7 @@ dependencies = [ "serde", "serde_json", "sqlx", - "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9)", + "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2)", "test-log", "thegraph-core", "thegraph-graphql-http", @@ -3153,7 +3202,7 @@ dependencies = [ "build-info", "build-info-build", "clap", - "graphql", + "graphql 0.3.0", "hex-literal", "indexer-common", "indexer-config", @@ -3197,7 +3246,7 @@ dependencies = [ "serde_json", "sqlx", "tap_aggregator", - "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9)", + "tap_core 1.0.0 (git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2)", "tempfile", "thegraph-core", "thiserror", @@ -3292,6 +3341,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.13.0" @@ -3555,6 +3613,19 @@ dependencies = [ "spin", ] +[[package]] +name = "lexical-core" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe" +dependencies = [ + "arrayvec 0.5.2", + "bitflags 1.3.2", + "cfg-if", + "ryu", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.159" @@ 
-3799,6 +3870,17 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" +[[package]] +name = "nom" +version = "5.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08959a387a676302eebf4ddbcbc611da04285579f76f88ee0506c63b1a61dd4b" +dependencies = [ + "lexical-core", + "memchr", + "version_check", +] + [[package]] name = "nom" version = "7.1.3" @@ -3831,6 +3913,31 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -3858,6 +3965,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -3884,6 +4001,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -4056,7 +4185,7 @@ version = "3.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "306800abfa29c7f16596b5970a588435e3d5b3149683d00c12b699cc19f895ee" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "bitvec", "byte-slice-cast", "impl-trait-for-tuples", @@ -4915,7 +5044,7 @@ dependencies = [ "ark-ff 0.4.2", "bytes", "fastrlp", - "num-bigint", + "num-bigint 0.4.6", "num-traits", "parity-scale-codec", "primitive-types", @@ -4940,7 +5069,7 @@ version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b082d80e3e3cc52b2ed634388d436fe1f4de6af5786cc2de9ba9737527bdf555" dependencies = [ - "arrayvec", + "arrayvec 0.7.6", "borsh", "bytes", "num-traits", @@ -5233,7 +5362,7 @@ dependencies = [ "core-foundation", "core-foundation-sys", "libc", - "num-bigint", + "num-bigint 0.4.6", "security-framework-sys", ] @@ -5514,7 +5643,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ - "num-bigint", + "num-bigint 0.4.6", "num-traits", "thiserror", "time", @@ -5615,7 +5744,7 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" dependencies = [ - "nom", + "nom 7.1.3", "unicode_categories", ] @@ -5785,7 +5914,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.6", "once_cell", "rand 0.8.5", "rust_decimal", @@ -6038,17 +6167,14 @@ dependencies = [ [[package]] name = "tap_core" version = 
"1.0.0" -source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=eb8447e#eb8447ed4566ced6846c03b510b25b915f985186" +source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=3fe6bc2#3fe6bc27c10161e5a1c011c088281d3da227df53" dependencies = [ "alloy", "anyhow", + "anymap3", "async-trait", "rand 0.8.5", - "rand_core 0.6.4", - "rstest", "serde", - "strum 0.24.1", - "strum_macros 0.24.3", "thiserror", "tokio", ] @@ -6056,13 +6182,17 @@ dependencies = [ [[package]] name = "tap_core" version = "1.0.0" -source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=ff856d9#ff856d966112af4c4d554a81154797fae4b335d9" +source = "git+https://github.com/semiotic-ai/timeline-aggregation-protocol?rev=eb8447e#eb8447ed4566ced6846c03b510b25b915f985186" dependencies = [ "alloy", "anyhow", "async-trait", "rand 0.8.5", + "rand_core 0.6.4", + "rstest", "serde", + "strum 0.24.1", + "strum_macros 0.24.3", "thiserror", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 6bf9e0b65..ed66b25d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ sqlx = { version = "0.8.2", features = [ tracing = { version = "0.1.40", default-features = false } bigdecimal = "0.4.3" build-info = "0.0.39" -tap_core = { git = "https://github.com/semiotic-ai/timeline-aggregation-protocol", rev = "ff856d9", default-features = false } +tap_core = { git = "https://github.com/semiotic-ai/timeline-aggregation-protocol", rev = "3fe6bc2", default-features = false } tracing-subscriber = { version = "0.3", features = [ "json", "env-filter", diff --git a/common/Cargo.toml b/common/Cargo.toml index 4cb1b25ba..911fcc048 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -26,6 +26,7 @@ graphql_client.workspace = true serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["fs", "tokio-macros"] } +cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3ed34ca" } regex = "1.7.1" axum-extra = { version = "0.9.3", features = [ "typed-header", diff --git a/common/src/cost_model.rs b/common/src/cost_model.rs new file mode 100644 index 000000000..a2d50570b --- /dev/null +++ b/common/src/cost_model.rs @@ -0,0 +1,465 @@ +// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sqlx::PgPool; +use std::{collections::HashSet, str::FromStr}; +use thegraph_core::{DeploymentId, ParseDeploymentIdError}; + +/// Internal cost model representation as stored in the database. +/// +/// These can have "global" as the deployment ID. +#[derive(Debug, Clone)] +pub(crate) struct DbCostModel { + pub deployment: String, + pub model: Option, + pub variables: Option, +} + +/// External representation of cost models. +/// +/// Here, any notion of "global" is gone and deployment IDs are valid deployment IDs. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostModel { + pub deployment: DeploymentId, + pub model: Option<String>, + pub variables: Option<Value>, +} + +impl TryFrom<DbCostModel> for CostModel { + type Error = ParseDeploymentIdError; + + fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> { + Ok(Self { + deployment: DeploymentId::from_str(&db_model.deployment)?, + model: db_model.model, + variables: db_model.variables, + }) + } +} + +impl From<CostModel> for DbCostModel { + fn from(model: CostModel) -> Self { + let deployment = model.deployment; + DbCostModel { + deployment: format!("{deployment:#x}"), + model: model.model, + variables: model.variables, + } + } +} + +/// Query cost models from the database, merging the global cost model in +/// whenever there is no cost model defined for a deployment. +pub async fn cost_models( + pool: &PgPool, + deployments: &[DeploymentId], +) -> Result<Vec<CostModel>, anyhow::Error> { + let hex_ids = deployments + .iter() + .map(|d| format!("{d:#x}")) + .collect::<Vec<_>>(); + + let mut models = if deployments.is_empty() { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment != 'global' + ORDER BY deployment ASC + "# + ) + .fetch_all(pool) + .await? + } else { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = ANY($1) + AND deployment != 'global' + ORDER BY deployment ASC + "#, + &hex_ids + ) + .fetch_all(pool) + .await? + } + .into_iter() + .map(CostModel::try_from) + .collect::<Result<Vec<_>, _>>()?; + + let deployments_with_models = models + .iter() + .map(|model| &model.deployment) + .collect::<HashSet<_>>(); + + let deployments_without_models = deployments + .iter() + .filter(|deployment| !deployments_with_models.contains(deployment)) + .collect::<Vec<_>>(); + + // Query the global cost model + if let Some(global_model) = global_cost_model(pool).await? { + // For all deployments that have a cost model, merge it with the global one + models = models + .into_iter() + .map(|model| merge_global(model, &global_model)) + // Inject a cost model for all deployments that don't have one + .chain( + deployments_without_models + .into_iter() + .map(|deployment| CostModel { + deployment: deployment.to_owned(), + model: global_model.model.clone(), + variables: global_model.variables.clone(), + }), + ) + .collect(); + } + + Ok(models) +} + +/// Query the database for the cost model of a given deployment ID. +pub async fn cost_model( + pool: &PgPool, + deployment: &DeploymentId, +) -> Result<Option<CostModel>, anyhow::Error> { + let model = sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = $1 + AND deployment != 'global' + "#, + format!("{:#x}", deployment), + ) + .fetch_optional(pool) + .await?
+ .map(CostModel::try_from) + .transpose()?; + + let global_model = global_cost_model(pool).await?; + + Ok(match (model, global_model) { + // If we have no global model, return whatever we can find for the deployment + (None, None) => None, + (Some(model), None) => Some(model), + + // If we have a cost model and a global cost model, merge them + (Some(model), Some(global_model)) => Some(merge_global(model, &global_model)), + + // If we have only a global model, use that + (None, Some(global_model)) => Some(CostModel { + deployment: deployment.to_owned(), + model: global_model.model, + variables: global_model.variables, + }), + }) +} + +/// Query global cost model +pub(crate) async fn global_cost_model(pool: &PgPool) -> Result<Option<DbCostModel>, anyhow::Error> { + sqlx::query_as!( + DbCostModel, + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment = $1 + "#, + "global" + ) + .fetch_optional(pool) + .await + .map_err(Into::into) +} + +fn merge_global(model: CostModel, global_model: &DbCostModel) -> CostModel { + CostModel { + deployment: model.deployment, + model: model.model.clone().or(global_model.model.clone()), + variables: model.variables.clone().or(global_model.variables.clone()), + } +} + +#[cfg(test)] +pub(crate) mod test { + + use std::str::FromStr; + + use sqlx::PgPool; + + use super::*; + + pub async fn add_cost_models(pool: &PgPool, models: Vec<DbCostModel>) { + for model in models { + sqlx::query!( + r#" + INSERT INTO "CostModels" (deployment, model) + VALUES ($1, $2); + "#, + model.deployment, + model.model, + ) + .execute(pool) + .await + .expect("Create test instance in db"); + } + } + + pub fn to_db_models(models: Vec<CostModel>) -> Vec<DbCostModel> { + models.into_iter().map(DbCostModel::from).collect() + } + + pub fn global_cost_model() -> DbCostModel { + DbCostModel { + deployment: "global".to_string(), + model: Some("default => 0.00001;".to_string()), + variables: None, + } + } + + pub fn test_data() -> Vec<CostModel> { + vec![ + CostModel { + deployment: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + .parse() + .unwrap(), + model: None, + variables: None, + }, + CostModel { + deployment: "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea" + .parse() + .unwrap(), + model: Some("default => 0.00025;".to_string()), + variables: None, + }, + CostModel { + deployment: "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc" + .parse() + .unwrap(), + model: Some("default => 0.00012;".to_string()), + variables: None, + }, + ] + } + + #[sqlx::test(migrations = "../migrations")] + async fn success_cost_models(pool: PgPool) { + let test_models = test_data(); + let test_deployments = test_models + .iter() + .map(|model| model.deployment) + .collect::<HashSet<_>>(); + + add_cost_models(&pool, to_db_models(test_models.clone())).await; + + // First test: query without deployment filter + let models = cost_models(&pool, &[]) + .await + .expect("cost models query without deployment filter"); + + // We expect as many models as we have in the test data + assert_eq!(models.len(), test_models.len()); + + // We expect models for all test deployments to be present and + // identical to the test data + for test_deployment in test_deployments.iter() { + let test_model = test_models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in test data"); + + let model = models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in query result"); + 
assert_eq!(test_model.model, model.model); + } + + // Second test: query with a deployment filter + let sample_deployments = vec![ + test_models.first().unwrap().deployment, + test_models.get(1).unwrap().deployment, + ]; + let models = cost_models(&pool, &sample_deployments) + .await + .expect("cost models query with deployment filter"); + + // Expect two cost models to be returned + assert_eq!(models.len(), sample_deployments.len()); + + // Expect both returned deployments to be identical to the test data + for test_deployment in sample_deployments.iter() { + let test_model = test_models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in test data"); + + let model = models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for test deployment in query result"); + + assert_eq!(test_model.model, model.model); + } + } + + #[sqlx::test(migrations = "../migrations")] + async fn global_fallback_cost_models(pool: PgPool) { + let test_models = test_data(); + let test_deployments = test_models + .iter() + .map(|model| model.deployment) + .collect::<HashSet<_>>(); + let global_model = global_cost_model(); + + add_cost_models(&pool, to_db_models(test_models.clone())).await; + add_cost_models(&pool, vec![global_model.clone()]).await; + + // First test: fetch cost models without filtering by deployment + let models = cost_models(&pool, &[]) + .await + .expect("cost models query without deployments filter"); + + // Since we've defined 3 cost models and we did not provide a filter, we + // expect all of them to be returned except for the global cost model + assert_eq!(models.len(), test_models.len()); + + // Expect all test deployments to be present in the query result + for test_deployment in test_deployments.iter() { + let test_model = test_models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for deployment in test data"); + + let model = models + .iter() + .find(|model| &model.deployment == test_deployment) + .expect("finding cost model for deployment in query result"); + + if test_model.model.is_some() { + // If the test model has a model definition, we expect that to be returned + assert_eq!(model.model, test_model.model); + } else { + // If the test model has no model definition, we expect the global + // model definition to be returned + assert_eq!(model.model, global_model.model); + } + } + + // Second test: fetch cost models, filtering by the first two deployment IDs + let sample_deployments = vec![ + test_models.first().unwrap().deployment, + test_models.get(1).unwrap().deployment, + ]; + let models = dbg!(cost_models(&pool, &sample_deployments).await) + .expect("cost models query with deployments filter"); + + // We've filtered by two deployment IDs and are expecting two cost models to be returned + assert_eq!(models.len(), sample_deployments.len()); + + for test_deployment in sample_deployments { + let test_model = test_models + .iter() + .find(|model| model.deployment == test_deployment) + .expect("finding cost model for deployment in test data"); + + let model = models + .iter() + .find(|model| model.deployment == test_deployment) + .expect("finding cost model for deployment in query result"); + + if test_model.model.is_some() { + // If the test model has a model definition, we expect that to be returned + assert_eq!(model.model, test_model.model); + } else { + // If the test model has no model definition, we expect the global + // model
definition to be returned + assert_eq!(model.model, global_model.model); + } + } + + // Third test: query for missing cost model + let missing_deployment = + DeploymentId::from_str("Qmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").unwrap(); + let models = cost_models(&pool, &[missing_deployment]) + .await + .expect("cost models query for missing deployment"); + + // The deployment may be missing in the database but we have a global model + // and expect that to be returned, with the missing deployment ID + let missing_model = models + .iter() + .find(|m| m.deployment == missing_deployment) + .expect("finding missing deployment"); + assert_eq!(missing_model.model, global_model.model); + } + + #[sqlx::test(migrations = "../migrations")] + async fn success_cost_model(pool: PgPool) { + add_cost_models(&pool, to_db_models(test_data())).await; + + let deployment_id_from_bytes = DeploymentId::from_str( + "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea", + ) + .unwrap(); + let deployment_id_from_hash = + DeploymentId::from_str("Qmb5Ysp5oCUXhLA8NmxmYKDAX2nCMnh7Vvb5uffb9n5vss").unwrap(); + + assert_eq!(deployment_id_from_bytes, deployment_id_from_hash); + + let model = cost_model(&pool, &deployment_id_from_bytes) + .await + .expect("cost model query") + .expect("cost model for deployment"); + + assert_eq!(model.deployment, deployment_id_from_hash); + assert_eq!(model.model, Some("default => 0.00025;".to_string())); + } + + #[sqlx::test(migrations = "../migrations")] + async fn global_fallback_cost_model(pool: PgPool) { + let test_models = test_data(); + let global_model = global_cost_model(); + + add_cost_models(&pool, to_db_models(test_models.clone())).await; + add_cost_models(&pool, vec![global_model.clone()]).await; + + // Test that the behavior is correct for existing deployments + for test_model in test_models { + let model = cost_model(&pool, &test_model.deployment) + .await + .expect("cost model query") + .expect("global cost model fallback"); + + assert_eq!(model.deployment, test_model.deployment); + + if test_model.model.is_some() { + // If the test model has a model definition, we expect that to be returned + assert_eq!(model.model, test_model.model); + } else { + // If the test model has no model definition, we expect the global + // model definition to be returned + assert_eq!(model.model, global_model.model); + } + } + + // Test that querying a non-existing deployment returns the global cost model + let missing_deployment = + DeploymentId::from_str("Qmnononononononononononononononononononononono").unwrap(); + let model = cost_model(&pool, &missing_deployment) + .await + .expect("cost model query") + .expect("global cost model fallback"); + assert_eq!(model.deployment, missing_deployment); + assert_eq!(model.model, global_model.model); + } +} diff --git a/common/src/indexer_service/http/indexer_service.rs b/common/src/indexer_service/http/indexer_service.rs index e3aca9640..887db7f82 100644 --- a/common/src/indexer_service/http/indexer_service.rs +++ b/common/src/indexer_service/http/indexer_service.rs @@ -66,15 +66,14 @@ pub enum AttestationOutput { #[async_trait] pub trait IndexerServiceImpl { type Error: std::error::Error; - type Request: DeserializeOwned + Send + Debug + Serialize; type Response: IndexerServiceResponse + Sized; type State: Send + Sync; - async fn process_request( + async fn process_request<Request: DeserializeOwned + Send + Debug + Serialize>( &self, manifest_id: DeploymentId, - request: Self::Request, - ) -> Result<(Self::Request, Self::Response), Self::Error>; + request: Request, + ) -> 
Result<(Request, Self::Response), Self::Error>; } #[derive(Debug, Error)] diff --git a/common/src/indexer_service/http/request_handler.rs b/common/src/indexer_service/http/request_handler.rs index b7f23d681..c3a7a6486 100644 --- a/common/src/indexer_service/http/request_handler.rs +++ b/common/src/indexer_service/http/request_handler.rs @@ -13,10 +13,13 @@ use axum_extra::TypedHeader; use lazy_static::lazy_static; use prometheus::{register_counter_vec, register_histogram_vec, CounterVec, HistogramVec}; use reqwest::StatusCode; +use tap_core::receipt::Context; use thegraph_core::DeploymentId; use tracing::trace; -use crate::indexer_service::http::IndexerServiceResponse; +use serde_json::value::RawValue; + +use crate::{indexer_service::http::IndexerServiceResponse, tap::AgoraQuery}; use super::{ indexer_service::{AttestationOutput, IndexerServiceError, IndexerServiceState}, @@ -77,7 +80,13 @@ where { trace!("Handling request for deployment `{manifest_id}`"); - let request = + #[derive(Debug, serde::Deserialize, serde::Serialize)] + pub struct QueryBody { + pub query: String, + pub variables: Option<Box<RawValue>>, + } + + let request: QueryBody = serde_json::from_slice(&body).map_err(|e| IndexerServiceError::InvalidRequest(e.into()))?; let Some(receipt) = receipt.into_signed_receipt() else { @@ -110,6 +119,18 @@ where let allocation_id = receipt.message.allocation_id; + let variables = request + .variables + .as_ref() + .map(ToString::to_string) + .unwrap_or_default(); + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id: manifest_id, + query: request.query.clone(), + variables, + }); + // recover the signer address // get escrow accounts from eventual // return sender from signer @@ -141,7 +162,7 @@ where // Verify the receipt and store it in the database state .tap_manager - .verify_and_store_receipt(receipt) + .verify_and_store_receipt(&ctx, receipt) .await .inspect_err(|_| { FAILED_RECEIPT diff --git a/common/src/lib.rs b/common/src/lib.rs index 63470d754..7b20b85c2 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -4,6 +4,7 @@ pub mod address; pub mod allocations; pub mod attestations; +pub mod cost_model; pub mod escrow_accounts; pub mod graphql; pub mod indexer_service; diff --git a/common/src/tap.rs b/common/src/tap.rs index 2d23ca922..151281b43 100644 --- a/common/src/tap.rs +++ b/common/src/tap.rs @@ -6,6 +6,7 @@ use crate::tap::checks::deny_list_check::DenyListCheck; use crate::tap::checks::receipt_max_val_check::ReceiptMaxValueCheck; use crate::tap::checks::sender_balance_check::SenderBalanceCheck; use crate::tap::checks::timestamp_check::TimestampCheck; +use crate::tap::checks::value_check::MinimumValue; use crate::{escrow_accounts::EscrowAccounts, prelude::Allocation}; use alloy::dyn_abi::Eip712Domain; use alloy::primitives::Address; @@ -24,6 +25,9 @@ use tracing::error; mod checks; mod receipt_store; +pub use checks::value_check::AgoraQuery; + +#[derive(Clone)] pub struct IndexerTapContext { domain_separator: Arc<Eip712Domain>, receipt_producer: Sender<DatabaseReceipt>, @@ -52,8 +56,9 @@ impl IndexerTapContext { domain_separator.clone(), )), Arc::new(TimestampCheck::new(timestamp_error_tolerance)), - Arc::new(DenyListCheck::new(pgpool, escrow_accounts, domain_separator).await), + Arc::new(DenyListCheck::new(pgpool.clone(), escrow_accounts, domain_separator).await), Arc::new(ReceiptMaxValueCheck::new(receipt_max_value)), + Arc::new(MinimumValue::new(pgpool).await), ] }
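The handler above carries the query to the receipt checks through tap_core's `Context`, a by-type map (hence the new `anymap3` dependency in Cargo.lock). A minimal sketch of that round trip, using only calls that appear in this diff (the `main` wiring itself is illustrative):

    use tap_core::receipt::Context;
    use thegraph_core::DeploymentId;

    // Same shape as the AgoraQuery exported from common/src/tap.rs below.
    pub struct AgoraQuery {
        pub deployment_id: DeploymentId,
        pub query: String,
        pub variables: String,
    }

    fn main() {
        let deployment_id: DeploymentId =
            "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea"
                .parse()
                .unwrap();

        // The request handler inserts one AgoraQuery per request...
        let mut ctx = Context::new();
        ctx.insert(AgoraQuery {
            deployment_id,
            query: "query { a(skip: 10) }".to_string(),
            variables: "{}".to_string(),
        });

        // ...and a check retrieves it by type, exactly like MinimumValue's
        // `ctx.get()` further down in this diff.
        let q: &AgoraQuery = ctx.get().expect("AgoraQuery was inserted above");
        assert_eq!(q.query, "query { a(skip: 10) }");
    }

diff --git a/common/src/tap/checks.rs b/common/src/tap/checks.rs index 4da34f594..d4e23e9b2 100644 --- 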
a/common/src/tap/checks.rs +++ b/common/src/tap/checks.rs @@ -6,3 +6,4 @@ pub mod deny_list_check; pub mod receipt_max_val_check; pub mod sender_balance_check; pub mod timestamp_check; +pub mod value_check; diff --git a/common/src/tap/checks/allocation_eligible.rs b/common/src/tap/checks/allocation_eligible.rs index ee4e752c7..a2a106c64 100644 --- a/common/src/tap/checks/allocation_eligible.rs +++ b/common/src/tap/checks/allocation_eligible.rs @@ -27,7 +27,11 @@ impl AllocationEligible { } #[async_trait::async_trait] impl Check for AllocationEligible { - async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState<Checking>, + ) -> CheckResult { let allocation_id = receipt.signed_receipt().message.allocation_id; if !self .indexer_allocations diff --git a/common/src/tap/checks/deny_list_check.rs b/common/src/tap/checks/deny_list_check.rs index 469e6a848..cf751d29a 100644 --- a/common/src/tap/checks/deny_list_check.rs +++ b/common/src/tap/checks/deny_list_check.rs @@ -150,7 +150,11 @@ impl DenyListCheck { #[async_trait::async_trait] impl Check for DenyListCheck { - async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState<Checking>, + ) -> CheckResult { let receipt_signer = receipt .signed_receipt() .recover_signer(&self.domain_separator) @@ -195,7 +199,7 @@ mod tests { use std::str::FromStr; use alloy::hex::ToHexExt; - use tap_core::receipt::ReceiptWithState; + use tap_core::receipt::{Context, ReceiptWithState}; use crate::test_vectors::{self, create_signed_receipt, TAP_SENDER}; @@ -241,7 +245,10 @@ mod tests { let checking_receipt = ReceiptWithState::new(signed_receipt); // Check that the receipt is rejected - assert!(deny_list_check.check(&checking_receipt).await.is_err()); + assert!(deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .is_err()); } #[sqlx::test(migrations = "../migrations")] @@ -255,7 +262,10 @@ // Check that the receipt is valid let checking_receipt = ReceiptWithState::new(signed_receipt); - deny_list_check.check(&checking_receipt).await.unwrap(); + deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .unwrap(); // Add the sender to the denylist sqlx::query!( @@ -271,7 +281,10 @@ // Check that the receipt is rejected tokio::time::sleep(std::time::Duration::from_millis(100)).await; - assert!(deny_list_check.check(&checking_receipt).await.is_err()); + assert!(deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .is_err()); // Remove the sender from the denylist sqlx::query!( @@ -287,6 +300,9 @@ // Check that the receipt is valid again tokio::time::sleep(std::time::Duration::from_millis(100)).await; - deny_list_check.check(&checking_receipt).await.unwrap(); + deny_list_check + .check(&Context::new(), &checking_receipt) + .await + .unwrap(); } } diff --git a/common/src/tap/checks/receipt_max_val_check.rs b/common/src/tap/checks/receipt_max_val_check.rs index 1dc7c26b6..8d480fff3 100--- a/common/src/tap/checks/receipt_max_val_check.rs +++ b/common/src/tap/checks/receipt_max_val_check.rs @@ -20,7 +20,11 @@ impl ReceiptMaxValueCheck { #[async_trait::async_trait] impl Check for ReceiptMaxValueCheck { - async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState<Checking>, + ) -> CheckResult { let receipt_value = 
receipt.signed_receipt().message.value; if receipt_value < self.receipt_max_value { @@ -42,6 +46,7 @@ mod tests { use std::str::FromStr; use std::time::Duration; use std::time::SystemTime; + use tap_core::receipt::Context; use super::*; use crate::tap::Eip712Domain; @@ -92,20 +97,29 @@ async fn test_receipt_lower_than_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT - 1); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_ok()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_ok()); } #[tokio::test] async fn test_receipt_higher_than_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT + 1); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } #[tokio::test] async fn test_receipt_same_as_limit() { let signed_receipt = create_signed_receipt_with_custom_value(RECEIPT_LIMIT); let timestamp_check = ReceiptMaxValueCheck::new(RECEIPT_LIMIT); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } } diff --git a/common/src/tap/checks/sender_balance_check.rs b/common/src/tap/checks/sender_balance_check.rs index b0269e71a..08a822759 100644 --- a/common/src/tap/checks/sender_balance_check.rs +++ b/common/src/tap/checks/sender_balance_check.rs @@ -30,7 +30,11 @@ impl SenderBalanceCheck { #[async_trait::async_trait] impl Check for SenderBalanceCheck { - async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState<Checking>, + ) -> CheckResult { let escrow_accounts_snapshot = self.escrow_accounts.value_immediate().unwrap_or_default(); let receipt_signer = receipt
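Every check in this directory is updated the same way: the tap_core `Check` trait now receives the `Context` alongside the receipt. For orientation, a minimal no-op implementation against the new signature (a sketch, not part of this change) would look like:

    use tap_core::receipt::{
        checks::{Check, CheckResult},
        state::Checking,
        Context, ReceiptWithState,
    };

    struct AlwaysAccept;

    #[async_trait::async_trait]
    impl Check for AlwaysAccept {
        async fn check(
            &self,
            _ctx: &Context,
            _receipt: &ReceiptWithState<Checking>,
        ) -> CheckResult {
            // Real checks inspect receipt fields (timestamp, value, signer) or
            // typed Context entries, as MinimumValue does with AgoraQuery below.
            Ok(())
        }
    }
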
diff --git a/common/src/tap/checks/timestamp_check.rs b/common/src/tap/checks/timestamp_check.rs index 7504433e9..219742928 100644 --- a/common/src/tap/checks/timestamp_check.rs +++ b/common/src/tap/checks/timestamp_check.rs @@ -23,7 +23,11 @@ impl TimestampCheck { #[async_trait::async_trait] impl Check for TimestampCheck { - async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult { + async fn check( + &self, + _: &tap_core::receipt::Context, + receipt: &ReceiptWithState<Checking>, + ) -> CheckResult { let timestamp_now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map_err(|e| CheckError::Failed(e.into()))?; @@ -54,7 +58,7 @@ mod tests { use super::*; use crate::tap::Eip712Domain; use tap_core::{ - receipt::{checks::Check, state::Checking, Receipt, ReceiptWithState}, + receipt::{checks::Check, state::Checking, Context, Receipt, ReceiptWithState}, signed_message::EIP712SignedMessage, tap_eip712_domain, }; @@ -98,7 +102,10 @@ let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_ok()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_ok()); } #[tokio::test] @@ -111,7 +118,10 @@ let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } #[tokio::test] @@ -124,6 +134,9 @@ let timestamp_ns = timestamp as u64; let signed_receipt = create_signed_receipt_with_custom_timestamp(timestamp_ns); let timestamp_check = TimestampCheck::new(Duration::from_secs(30)); - assert!(timestamp_check.check(&signed_receipt).await.is_err()); + assert!(timestamp_check + .check(&Context::new(), &signed_receipt) + .await + .is_err()); } } diff --git a/common/src/tap/checks/value_check.rs b/common/src/tap/checks/value_check.rs new file mode 100644 index 000000000..daf9237c3 --- /dev/null +++ b/common/src/tap/checks/value_check.rs @@ -0,0 +1,549 @@ +// Copyright 2023-, GraphOps and Semiotic Labs. +// SPDX-License-Identifier: Apache-2.0 + +use ::cost_model::CostModel; +use anyhow::anyhow; +use bigdecimal::ToPrimitive; +use sqlx::{ + postgres::{PgListener, PgNotification}, + PgPool, +}; +use std::{ + collections::HashMap, + str::FromStr, + sync::{Arc, RwLock}, +}; +use thegraph_core::DeploymentId; +use tracing::error; + +use tap_core::receipt::{ + checks::{Check, CheckError, CheckResult}, + state::Checking, + Context, ReceiptWithState, +}; + +use crate::cost_model; + +// We only accept receipts with a minimum value of 1 wei GRT. +const MINIMAL_VALUE: u128 = 1; + +/// Represents a query that can be checked against an agora model +/// +/// It contains the deployment_id used to select which Agora model to evaluate, +/// along with the query and variables needed for the evaluation +pub struct AgoraQuery { + pub deployment_id: DeploymentId, + pub query: String, + pub variables: String, +} + +type CostModelMap = Arc<RwLock<HashMap<DeploymentId, CostModel>>>; +type GlobalModel = Arc<RwLock<Option<CostModel>>>; + +/// Check that a receipt's value meets the minimum price of its query +/// +/// It keeps all the information it needs in memory to +/// make the check as fast as possible +pub struct MinimumValue { + cost_model_map: CostModelMap, + global_model: GlobalModel, + watcher_cancel_token: tokio_util::sync::CancellationToken, +} + +struct CostModelWatcher { + pgpool: PgPool, + + cost_models: CostModelMap, + global_model: GlobalModel, +} + +impl CostModelWatcher { + async fn cost_models_watcher( + pgpool: PgPool, + mut pglistener: PgListener, + cost_models: CostModelMap, + global_model: GlobalModel, + cancel_token: tokio_util::sync::CancellationToken, + ) { + let cost_model_watcher = CostModelWatcher { + pgpool, + global_model, + cost_models, + }; + + loop { + tokio::select! { + _ = cancel_token.cancelled() => { + break; + } + Ok(pg_notification) = pglistener.recv() => { + cost_model_watcher.new_notification( + pg_notification, + ).await; + } + } + } + } + + async fn new_notification(&self, pg_notification: PgNotification) { + let payload = pg_notification.payload(); + let cost_model_notification: Result<CostModelNotification, _> = + serde_json::from_str(payload); + + match cost_model_notification { + Ok(CostModelNotification::Insert { + deployment, + model, + variables, + }) => self.handle_insert(deployment, model, variables), + Ok(CostModelNotification::Delete { deployment }) => self.handle_delete(deployment), + // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost + // model cache. 
+ Err(_) => self.handle_unexpected_notification(payload).await, + } + } + + fn handle_insert(&self, deployment: String, model: String, variables: String) { + let model = compile_cost_model(model, variables).unwrap(); + + match deployment.as_str() { + "global" => { + *self.global_model.write().unwrap() = Some(model); + } + deployment_id => match DeploymentId::from_str(deployment_id) { + Ok(deployment_id) => { + let mut cost_model_write = self.cost_models.write().unwrap(); + cost_model_write.insert(deployment_id, model); + } + Err(_) => { + error!( + "Received insert request for an invalid deployment_id: {}", + deployment_id + ) + } + }, + }; + } + + fn handle_delete(&self, deployment: String) { + match deployment.as_str() { + "global" => { + *self.global_model.write().unwrap() = None; + } + deployment_id => match DeploymentId::from_str(deployment_id) { + Ok(deployment_id) => { + self.cost_models.write().unwrap().remove(&deployment_id); + } + Err(_) => { + error!( + "Received delete request for an invalid deployment_id: {}", + deployment_id + ) + } + }, + }; + } + + async fn handle_unexpected_notification(&self, payload: &str) { + error!( + "Received an unexpected cost model table notification: {}. Reloading entire \ cost model.", + payload + ); + + MinimumValue::value_check_reload( + &self.pgpool, + self.cost_models.clone(), + self.global_model.clone(), + ) + .await + .expect("should be able to reload cost models") + } +} + +impl Drop for MinimumValue { + fn drop(&mut self) { + // Clean shutdown for the minimum value check + // Since it's not a critical task, we don't wait for it to finish (join). + self.watcher_cancel_token.cancel(); + } +} + +impl MinimumValue { + pub async fn new(pgpool: PgPool) -> Self { + let cost_model_map: CostModelMap = Default::default(); + let global_model: GlobalModel = Default::default(); + Self::value_check_reload(&pgpool, cost_model_map.clone(), global_model.clone()) + .await + .expect("should be able to reload cost models"); + + let mut pglistener = PgListener::connect_with(&pgpool.clone()).await.unwrap(); + pglistener + .listen("cost_models_update_notification") + .await + .expect( + "should be able to subscribe to Postgres Notify events on the channel \ 'cost_models_update_notification'", + ); + + let watcher_cancel_token = tokio_util::sync::CancellationToken::new(); + tokio::spawn(CostModelWatcher::cost_models_watcher( + pgpool.clone(), + pglistener, + cost_model_map.clone(), + global_model.clone(), + watcher_cancel_token.clone(), + )); + + Self { + global_model, + cost_model_map, + watcher_cancel_token, + } + } + + fn expected_value(&self, agora_query: &AgoraQuery) -> anyhow::Result<u128> { + // get agora model for the deployment_id + let model = self.cost_model_map.read().unwrap(); + let subgraph_model = model.get(&agora_query.deployment_id); + let global_model = self.global_model.read().unwrap(); + + let expected_value = match (subgraph_model, global_model.as_ref()) { + (Some(model), _) | (_, Some(model)) => model + .cost(&agora_query.query, &agora_query.variables) + .map(|fee| fee.to_u128()) + .ok() + .flatten(), + _ => None, + }; + + Ok(expected_value.unwrap_or(MINIMAL_VALUE)) + } + + async fn value_check_reload( + pgpool: &PgPool, + cost_model_map: CostModelMap, + global_model: GlobalModel, + ) -> anyhow::Result<()> { + let models = sqlx::query!( + r#" + SELECT deployment, model, variables + FROM "CostModels" + WHERE deployment != 'global' + ORDER BY deployment ASC + "# + ) + .fetch_all(pgpool) + .await?; + let models = models + .into_iter() + 
.flat_map(|record| { + let deployment_id = DeploymentId::from_str(&record.deployment).ok()?; + let model = compile_cost_model( + record.model?, + record.variables.map(|v| v.to_string()).unwrap_or_default(), + ) + .ok()?; + Some((deployment_id, model)) + }) + .collect::<HashMap<_, _>>(); + + *cost_model_map.write().unwrap() = models; + + *global_model.write().unwrap() = + cost_model::global_cost_model(pgpool) + .await? + .and_then(|model| { + compile_cost_model( + model.model.unwrap_or_default(), + model.variables.map(|v| v.to_string()).unwrap_or_default(), + ) + .ok() + }); + + Ok(()) + } +} + +#[async_trait::async_trait] +impl Check for MinimumValue { + async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult { + let agora_query = ctx + .get() + .ok_or(CheckError::Failed(anyhow!("Could not find agora query")))?; + + let expected_value = self + .expected_value(agora_query) + .map_err(CheckError::Failed)?; + + // get value + let value = receipt.signed_receipt().message.value; + + let should_accept = value >= expected_value; + + tracing::trace!( + value, + expected_value, + should_accept, + "Evaluating minimum query fee." + ); + + if should_accept { + Ok(()) + } else { + return Err(CheckError::Failed(anyhow!( + "Query receipt does not have the minimum value. Expected value: {}. Received value: {}.", + expected_value, value, + ))); + } + } +} + +fn compile_cost_model(model: String, variables: String) -> anyhow::Result<CostModel> { + if model.len() > (1 << 16) { + return Err(anyhow!("CostModelTooLarge")); + } + let model = CostModel::compile(&model, &variables)?; + Ok(model) +} + +#[derive(serde::Deserialize)] +#[serde(tag = "tg_op")] +enum CostModelNotification { + #[serde(rename = "INSERT")] + Insert { + deployment: String, + model: String, + variables: String, + }, + #[serde(rename = "DELETE")] + Delete { deployment: String }, +}
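On the arithmetic: Agora prices each top-level selection, so under the test model "default => 0.00025;" the two-selection query used in the tests below costs 0.0005 GRT, which is the 500000000000000 wei the tests pass as `minimal_value`. A rough sketch of that calculation with the same `cost-model` calls used above, assuming fees come back denominated in wei as those test numbers imply:

    use bigdecimal::ToPrimitive;
    use cost_model::CostModel;

    fn main() -> anyhow::Result<()> {
        // Same model string as the second entry in cost_model::test::test_data().
        let model = CostModel::compile("default => 0.00025;", "")?;

        // Two top-level selections at 0.00025 GRT each => 0.0005 GRT.
        let expected = model
            .cost("query { a(skip: 10), b(bob: 5) }", "")
            .ok()
            .and_then(|fee| fee.to_u128())
            // MinimumValue falls back to MINIMAL_VALUE (1 wei) when nothing matches.
            .unwrap_or(1);

        assert_eq!(expected, 500_000_000_000_000);
        Ok(())
    }
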
+#[cfg(test)] +mod tests { + use alloy::primitives::{address, Address}; + use std::time::Duration; + + use sqlx::PgPool; + use tap_core::receipt::{checks::Check, Context, ReceiptWithState}; + use tokio::time::sleep; + + use crate::{ + cost_model::test::{add_cost_models, global_cost_model, to_db_models}, + tap::AgoraQuery, + test_vectors::create_signed_receipt, + }; + + use super::MinimumValue; + + #[sqlx::test(migrations = "../migrations")] + async fn initialize_check(pgpool: PgPool) { + let check = MinimumValue::new(pgpool).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_initialize_check_with_models(pgpool: PgPool) { + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 2); + + // no global model + assert!(check.global_model.read().unwrap().is_none()); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_model_insert(pgpool: PgPool) { + let check = MinimumValue::new(pgpool.clone()).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); + + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + sleep(Duration::from_millis(200)).await; + + assert_eq!( + check.cost_model_map.read().unwrap().len(), + test_models.len() + ); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_model_remove(pgpool: PgPool) { + // insert 2 cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert_eq!(check.cost_model_map.read().unwrap().len(), 2); + + // remove + sqlx::query!(r#"DELETE FROM "CostModels""#) + .execute(&pgpool) + .await + .unwrap(); + + sleep(Duration::from_millis(200)).await; + + assert_eq!(check.cost_model_map.read().unwrap().len(), 0); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_start_global_model(pgpool: PgPool) { + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert!(check.global_model.read().unwrap().is_some()); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_watch_global_model(pgpool: PgPool) { + let check = MinimumValue::new(pgpool.clone()).await; + + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + sleep(Duration::from_millis(10)).await; + + assert!(check.global_model.read().unwrap().is_some()); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_remove_global_model(pgpool: PgPool) { + let global_model = global_cost_model(); + add_cost_models(&pgpool, vec![global_model.clone()]).await; + + let check = MinimumValue::new(pgpool.clone()).await; + assert!(check.global_model.read().unwrap().is_some()); + + sqlx::query!(r#"DELETE FROM "CostModels""#) + .execute(&pgpool) + .await + .unwrap(); + + sleep(Duration::from_millis(10)).await; + + assert!(check.global_model.read().unwrap().is_none()); + } + + const ALLOCATION_ID: Address = address!("deadbeefcafebabedeadbeefcafebabedeadbeef"); + + #[sqlx::test(migrations = "../migrations")] + async fn should_check_minimal_value(pgpool: PgPool) { + // insert cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool).await; + + let deployment_id = test_models[0].deployment; + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id, + query: "query { a(skip: 10), b(bob: 5) }".into(), + variables: "".into(), + }); + + let signed_receipt = create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, 0).await; + let receipt = ReceiptWithState::new(signed_receipt); + + assert!( + check.check(&ctx, &receipt).await.is_err(), + "Should deny if value is 0 for any query" + ); + + let signed_receipt = create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, 1).await; + let receipt = ReceiptWithState::new(signed_receipt); + assert!( + check.check(&ctx, &receipt).await.is_ok(), + "Should accept if value is more than 0 for any query" + ); + + let deployment_id = test_models[1].deployment; + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id, + query: "query { a(skip: 10), b(bob: 5) }".into(), + variables: "".into(), + }); + let minimal_value = 500000000000000; + + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_value - 1).await; + let receipt = ReceiptWithState::new(signed_receipt); + assert!( + check.check(&ctx, &receipt).await.is_err(), + "Should require minimal value" + ); + + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, 500000000000000).await; + let receipt = 
ReceiptWithState::new(signed_receipt); + check + .check(&ctx, &receipt) + .await + .expect("should accept equals minimal"); + + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_value + 1).await; + let receipt = ReceiptWithState::new(signed_receipt); + check + .check(&ctx, &receipt) + .await + .expect("should accept more than minimal"); + } + + #[sqlx::test(migrations = "../migrations")] + async fn should_check_using_global(pgpool: PgPool) { + // insert cost models for different deployment_id + let test_models = crate::cost_model::test::test_data(); + let global_model = global_cost_model(); + + add_cost_models(&pgpool, vec![global_model.clone()]).await; + add_cost_models(&pgpool, to_db_models(test_models.clone())).await; + + let check = MinimumValue::new(pgpool).await; + + let deployment_id = test_models[0].deployment; + let mut ctx = Context::new(); + ctx.insert(AgoraQuery { + deployment_id, + query: "query { a(skip: 10), b(bob: 5) }".into(), + variables: "".into(), + }); + + let minimal_global_value = 20000000000000; + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value - 1) + .await; + let receipt = ReceiptWithState::new(signed_receipt); + + assert!( + check.check(&ctx, &receipt).await.is_err(), + "Should deny less than global" + ); + + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value).await; + let receipt = ReceiptWithState::new(signed_receipt); + check + .check(&ctx, &receipt) + .await + .expect("should accept equals global"); + + let signed_receipt = + create_signed_receipt(ALLOCATION_ID, u64::MAX, u64::MAX, minimal_global_value + 1) + .await; + let receipt = ReceiptWithState::new(signed_receipt); + check + .check(&ctx, &receipt) + .await + .expect("should accept more than global"); + } +} diff --git a/migrations/20241024191258_add_cost_model_notification.down.sql b/migrations/20241024191258_add_cost_model_notification.down.sql new file mode 100644 index 000000000..16d32dfcc --- /dev/null +++ b/migrations/20241024191258_add_cost_model_notification.down.sql @@ -0,0 +1,4 @@ +-- Add down migration script here +DROP TRIGGER IF EXISTS cost_models_update ON "CostModels" CASCADE; + +DROP FUNCTION IF EXISTS cost_models_update_notify() CASCADE; diff --git a/migrations/20241024191258_add_cost_model_notification.up.sql b/migrations/20241024191258_add_cost_model_notification.up.sql new file mode 100644 index 000000000..c2f3fb37c --- /dev/null +++ b/migrations/20241024191258_add_cost_model_notification.up.sql @@ -0,0 +1,21 @@ +-- Add up migration script here +CREATE FUNCTION cost_models_update_notify() +RETURNS trigger AS +$$ +BEGIN + IF TG_OP = 'DELETE' THEN + PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "DELETE", "deployment": "%s"}', OLD.deployment)); + RETURN OLD; + ELSIF TG_OP = 'INSERT' THEN + PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "INSERT", "deployment": "%s", "model": "%s", "variables": "%s"}', NEW.deployment, NEW.model, NEW.variables)); + RETURN NEW; + ELSE + PERFORM pg_notify('cost_models_update_notification', format('{"tg_op": "%s", "deployment": "%s", "model": "%s", "variables": "%s" }', TG_OP, NEW.deployment, NEW.model, NEW.variables)); + RETURN NEW; + END IF; +END; +$$ LANGUAGE 'plpgsql'; + +CREATE TRIGGER cost_models_update AFTER INSERT OR UPDATE OR DELETE + ON "CostModels" + FOR EACH ROW EXECUTE PROCEDURE cost_models_update_notify();
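To connect the trigger to the watcher: each pg_notify payload is a small JSON document tagged by `tg_op`, which is exactly what the `CostModelNotification` enum in value_check.rs deserializes. A self-contained sketch of that decoding (the enum is copied from this diff; the sample payload mirrors an INSERT of the global model):

    use serde::Deserialize;

    #[derive(Deserialize)]
    #[serde(tag = "tg_op")]
    enum CostModelNotification {
        #[serde(rename = "INSERT")]
        Insert {
            deployment: String,
            model: String,
            variables: String,
        },
        #[serde(rename = "DELETE")]
        Delete { deployment: String },
    }

    fn main() {
        let payload =
            r#"{"tg_op": "INSERT", "deployment": "global", "model": "default => 0.00001;", "variables": ""}"#;

        match serde_json::from_str::<CostModelNotification>(payload) {
            // The watcher compiles the model and caches it under this key.
            Ok(CostModelNotification::Insert { deployment, .. }) => {
                println!("upsert cost model for {deployment}")
            }
            Ok(CostModelNotification::Delete { deployment }) => {
                println!("evict cost model for {deployment}")
            }
            // Payloads with any other tg_op fail to parse, and the watcher
            // responds by reloading every cost model from the database.
            Err(_) => println!("unexpected payload, full reload"),
        }
    }
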
diff --git a/service/src/database.rs b/service/src/database.rs index ab0cd80e0..7054f0358 100644 --- a/service/src/database.rs +++ b/service/src/database.rs @@ -1,13 +1,8 @@ // Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs. // SPDX-License-Identifier: Apache-2.0 -use std::time::Duration; -use std::{collections::HashSet, str::FromStr}; - -use serde::{Deserialize, Serialize}; -use serde_json::Value; use sqlx::{postgres::PgPoolOptions, PgPool}; -use thegraph_core::{DeploymentId, ParseDeploymentIdError}; +use std::time::Duration; use tracing::debug; pub async fn connect(url: &str) -> PgPool { @@ -20,481 +15,3 @@ .await .expect("Should be able to connect to the database") } - -/// Internal cost model representation as stored in the database. -/// -/// These can have "global" as the deployment ID. -#[derive(Debug, Clone)] -struct DbCostModel { - pub deployment: String, - pub model: Option<String>, - pub variables: Option<Value>, -} - -/// External representation of cost models. -/// -/// Here, any notion of "global" is gone and deployment IDs are valid deployment IDs. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CostModel { - pub deployment: DeploymentId, - pub model: Option<String>, - pub variables: Option<Value>, -} - -impl TryFrom<DbCostModel> for CostModel { - type Error = ParseDeploymentIdError; - - fn try_from(db_model: DbCostModel) -> Result<Self, Self::Error> { - Ok(Self { - deployment: DeploymentId::from_str(&db_model.deployment)?, - model: db_model.model, - variables: db_model.variables, - }) - } -} - -impl From<CostModel> for DbCostModel { - fn from(model: CostModel) -> Self { - let deployment = model.deployment; - DbCostModel { - deployment: format!("{deployment:#x}"), - model: model.model, - variables: model.variables, - } - } -} - -/// Query cost models from the database, merging the global cost model in -/// whenever there is no cost model defined for a deployment. -pub async fn cost_models( - pool: &PgPool, - deployments: &[DeploymentId], -) -> Result<Vec<CostModel>, anyhow::Error> { - let hex_ids = deployments - .iter() - .map(|d| format!("{d:#x}")) - .collect::<Vec<_>>(); - - let mut models = if deployments.is_empty() { - sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment != 'global' - ORDER BY deployment ASC - "# - ) - .fetch_all(pool) - .await? - } else { - sqlx::query_as!( - DbCostModel, - r#" - SELECT deployment, model, variables - FROM "CostModels" - WHERE deployment = ANY($1) - AND deployment != 'global' - ORDER BY deployment ASC - "#, - &hex_ids - ) - .fetch_all(pool) - .await? - } - .into_iter() - .map(CostModel::try_from) - .collect::<Result<Vec<_>, _>>()?; - - let deployments_with_models = models - .iter() - .map(|model| &model.deployment) - .collect::<HashSet<_>>(); - - let deployments_without_models = deployments - .iter() - .filter(|deployment| !deployments_with_models.contains(deployment)) - .collect::<Vec<_>>(); - - // Query the global cost model - if let Some(global_model) = global_cost_model(pool).await? 
-        // For all deployments that have a cost model, merge it with the global one
-        models = models
-            .into_iter()
-            .map(|model| merge_global(model, &global_model))
-            // Inject a cost model for all deployments that don't have one
-            .chain(
-                deployments_without_models
-                    .into_iter()
-                    .map(|deployment| CostModel {
-                        deployment: deployment.to_owned(),
-                        model: global_model.model.clone(),
-                        variables: global_model.variables.clone(),
-                    }),
-            )
-            .collect();
-    }
-
-    Ok(models)
-}
-
-/// Make database query for a cost model indexed by deployment id
-pub async fn cost_model(
-    pool: &PgPool,
-    deployment: &DeploymentId,
-) -> Result<Option<CostModel>, anyhow::Error> {
-    let model = sqlx::query_as!(
-        DbCostModel,
-        r#"
-        SELECT deployment, model, variables
-        FROM "CostModels"
-        WHERE deployment = $1
-        AND deployment != 'global'
-        "#,
-        format!("{:#x}", deployment),
-    )
-    .fetch_optional(pool)
-    .await?
-    .map(CostModel::try_from)
-    .transpose()?;
-
-    let global_model = global_cost_model(pool).await?;
-
-    Ok(match (model, global_model) {
-        // If we have no global model, return whatever we can find for the deployment
-        (None, None) => None,
-        (Some(model), None) => Some(model),
-
-        // If we have a cost model and a global cost model, merge them
-        (Some(model), Some(global_model)) => Some(merge_global(model, &global_model)),
-
-        // If we have only a global model, use that
-        (None, Some(global_model)) => Some(CostModel {
-            deployment: deployment.to_owned(),
-            model: global_model.model,
-            variables: global_model.variables,
-        }),
-    })
-}
-
-/// Query global cost model
-async fn global_cost_model(pool: &PgPool) -> Result<Option<DbCostModel>, anyhow::Error> {
-    sqlx::query_as!(
-        DbCostModel,
-        r#"
-        SELECT deployment, model, variables
-        FROM "CostModels"
-        WHERE deployment = $1
-        "#,
-        "global"
-    )
-    .fetch_optional(pool)
-    .await
-    .map_err(Into::into)
-}
-
-fn merge_global(model: CostModel, global_model: &DbCostModel) -> CostModel {
-    CostModel {
-        deployment: model.deployment,
-        model: model.model.clone().or(global_model.model.clone()),
-        variables: model.variables.clone().or(global_model.variables.clone()),
-    }
-}
-
-#[cfg(test)]
-mod test {
-
-    use std::str::FromStr;
-
-    use sqlx::PgPool;
-
-    use super::*;
-
-    async fn setup_cost_models_table(pool: &PgPool) {
-        sqlx::query!(
-            r#"
-            CREATE TABLE "CostModels"(
-                id INT,
-                deployment VARCHAR NOT NULL,
-                model TEXT,
-                variables JSONB,
-                PRIMARY KEY( deployment )
-            );
-            "#,
-        )
-        .execute(pool)
-        .await
-        .expect("Create test instance in db");
-    }
-
-    async fn add_cost_models(pool: &PgPool, models: Vec<DbCostModel>) {
-        for model in models {
-            sqlx::query!(
-                r#"
-                INSERT INTO "CostModels" (deployment, model)
-                VALUES ($1, $2);
-                "#,
-                model.deployment,
-                model.model,
-            )
-            .execute(pool)
-            .await
-            .expect("Create test instance in db");
-        }
-    }
-
-    fn to_db_models(models: Vec<CostModel>) -> Vec<DbCostModel> {
-        models.into_iter().map(DbCostModel::from).collect()
-    }
-
-    fn global_cost_model() -> DbCostModel {
-        DbCostModel {
-            deployment: "global".to_string(),
-            model: Some("default => 0.00001;".to_string()),
-            variables: None,
-        }
-    }
-
-    fn test_data() -> Vec<CostModel> {
-        vec![
-            CostModel {
-                deployment: "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
-                    .parse()
-                    .unwrap(),
-                model: None,
-                variables: None,
-            },
-            CostModel {
-                deployment: "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea"
-                    .parse()
-                    .unwrap(),
-                model: Some("default => 0.00025;".to_string()),
-                variables: None,
-            },
-            CostModel {
-                deployment: "0xcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"
-                    .parse()
-                    .unwrap(),
-                model: Some("default => 0.00012;".to_string()),
-                variables: None,
-            },
-        ]
-    }
-
-    #[sqlx::test]
-    async fn success_cost_models(pool: PgPool) {
-        let test_models = test_data();
-        let test_deployments = test_models
-            .iter()
-            .map(|model| model.deployment)
-            .collect::<HashSet<_>>();
-
-        setup_cost_models_table(&pool).await;
-        add_cost_models(&pool, to_db_models(test_models.clone())).await;
-
-        // First test: query without deployment filter
-        let models = cost_models(&pool, &[])
-            .await
-            .expect("cost models query without deployment filter");
-
-        // We expect as many models as we have in the test data
-        assert_eq!(models.len(), test_models.len());
-
-        // We expect models for all test deployments to be present and
-        // identical to the test data
-        for test_deployment in test_deployments.iter() {
-            let test_model = test_models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for test deployment in test data");
-
-            let model = models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for test deployment in query result");
-
-            assert_eq!(test_model.model, model.model);
-        }
-
-        // Second test: query with a deployment filter
-        let sample_deployments = vec![
-            test_models.first().unwrap().deployment,
-            test_models.get(1).unwrap().deployment,
-        ];
-        let models = cost_models(&pool, &sample_deployments)
-            .await
-            .expect("cost models query with deployment filter");
-
-        // Expect two cost models to be returned
-        assert_eq!(models.len(), sample_deployments.len());
-
-        // Expect both returned deployments to be identical to the test data
-        for test_deployment in sample_deployments.iter() {
-            let test_model = test_models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for test deployment in test data");
-
-            let model = models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for test deployment in query result");
-
-            assert_eq!(test_model.model, model.model);
-        }
-    }
-
-    #[sqlx::test]
-    async fn global_fallback_cost_models(pool: PgPool) {
-        let test_models = test_data();
-        let test_deployments = test_models
-            .iter()
-            .map(|model| model.deployment)
-            .collect::<HashSet<_>>();
-        let global_model = global_cost_model();
-
-        setup_cost_models_table(&pool).await;
-        add_cost_models(&pool, to_db_models(test_models.clone())).await;
-        add_cost_models(&pool, vec![global_model.clone()]).await;
-
-        // First test: fetch cost models without filtering by deployment
-        let models = cost_models(&pool, &[])
-            .await
-            .expect("cost models query without deployments filter");
-
-        // Since we've defined 3 cost models and we did not provide a filter, we
-        // expect all of them to be returned except for the global cost model
-        assert_eq!(models.len(), test_models.len());
-
-        // Expect all test deployments to be present in the query result
-        for test_deployment in test_deployments.iter() {
-            let test_model = test_models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for deployment in test data");
-
-            let model = models
-                .iter()
-                .find(|model| &model.deployment == test_deployment)
-                .expect("finding cost model for deployment in query result");
-
-            if test_model.model.is_some() {
-                // If the test model has a model definition, we expect that to be returned
-                assert_eq!(model.model, test_model.model);
-            } else {
-                // If the test model has no model definition, we expect the global
-                // model definition to be returned
-                assert_eq!(model.model, global_model.model);
-            }
-        }
-
-        // Second test: fetch cost models, filtering by the first two deployment IDs
-        let sample_deployments = vec![
-            test_models.first().unwrap().deployment,
-            test_models.get(1).unwrap().deployment,
-        ];
-        let models = dbg!(cost_models(&pool, &sample_deployments).await)
-            .expect("cost models query with deployments filter");
-
-        // We've filtered by two deployment IDs and are expecting two cost models to be returned
-        assert_eq!(models.len(), sample_deployments.len());
-
-        for test_deployment in sample_deployments {
-            let test_model = test_models
-                .iter()
-                .find(|model| model.deployment == test_deployment)
-                .expect("finding cost model for deployment in test data");
-
-            let model = models
-                .iter()
-                .find(|model| model.deployment == test_deployment)
-                .expect("finding cost model for deployment in query result");
-
-            if test_model.model.is_some() {
-                // If the test model has a model definition, we expect that to be returned
-                assert_eq!(model.model, test_model.model);
-            } else {
-                // If the test model has no model definition, we expect the global
-                // model definition to be returned
-                assert_eq!(model.model, global_model.model);
-            }
-        }
-
-        // Third test: query for missing cost model
-        let missing_deployment =
-            DeploymentId::from_str("Qmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").unwrap();
-        let models = cost_models(&pool, &[missing_deployment])
-            .await
-            .expect("cost models query for missing deployment");
-
-        // The deployment may be missing in the database but we have a global model
-        // and expect that to be returned, with the missing deployment ID
-        let missing_model = models
-            .iter()
-            .find(|m| m.deployment == missing_deployment)
-            .expect("finding missing deployment");
-        assert_eq!(missing_model.model, global_model.model);
-    }
-
-    #[sqlx::test]
-    async fn success_cost_model(pool: PgPool) {
-        setup_cost_models_table(&pool).await;
-        add_cost_models(&pool, to_db_models(test_data())).await;
-
-        let deployment_id_from_bytes = DeploymentId::from_str(
-            "0xbd499f7673ca32ef4a642207a8bebdd0fb03888cf2678b298438e3a1ae5206ea",
-        )
-        .unwrap();
-        let deployment_id_from_hash =
-            DeploymentId::from_str("Qmb5Ysp5oCUXhLA8NmxmYKDAX2nCMnh7Vvb5uffb9n5vss").unwrap();
-
-        assert_eq!(deployment_id_from_bytes, deployment_id_from_hash);
-
-        let model = cost_model(&pool, &deployment_id_from_bytes)
-            .await
-            .expect("cost model query")
-            .expect("cost model for deployment");
-
-        assert_eq!(model.deployment, deployment_id_from_hash);
-        assert_eq!(model.model, Some("default => 0.00025;".to_string()));
-    }
-
-    #[sqlx::test]
-    async fn global_fallback_cost_model(pool: PgPool) {
-        let test_models = test_data();
-        let global_model = global_cost_model();
-
-        setup_cost_models_table(&pool).await;
-        add_cost_models(&pool, to_db_models(test_models.clone())).await;
-        add_cost_models(&pool, vec![global_model.clone()]).await;
-
-        // Test that the behavior is correct for existing deployments
-        for test_model in test_models {
-            let model = cost_model(&pool, &test_model.deployment)
-                .await
-                .expect("cost model query")
-                .expect("global cost model fallback");
-
-            assert_eq!(model.deployment, test_model.deployment);
-
-            if test_model.model.is_some() {
-                // If the test model has a model definition, we expect that to be returned
-                assert_eq!(model.model, test_model.model);
-            } else {
-                // If the test model has no model definition, we expect the global
-                // model definition to be returned
-                assert_eq!(model.model, global_model.model);
-            }
-        }
-
-        // Test that querying a non-existing deployment returns the default cost model
-        let missing_deployment =
-            DeploymentId::from_str("Qmnononononononononononononononononononononono").unwrap();
-        let model = cost_model(&pool, &missing_deployment)
-            .await
-            .expect("cost model query")
-            .expect("global cost model fallback");
-        assert_eq!(model.deployment, missing_deployment);
-        assert_eq!(model.model, global_model.model);
-    }
-}
diff --git a/service/src/routes/cost.rs b/service/src/routes/cost.rs
index 08d4eb2cf..044a15f9f 100644
--- a/service/src/routes/cost.rs
+++ b/service/src/routes/cost.rs
@@ -7,6 +7,7 @@ use std::sync::Arc;
 use async_graphql::{Context, EmptyMutation, EmptySubscription, Object, Schema, SimpleObject};
 use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
 use axum::extract::State;
+use indexer_common::cost_model::{self, CostModel};
 use lazy_static::lazy_static;
 use prometheus::{
     register_counter, register_counter_vec, register_histogram, register_histogram_vec, Counter,
@@ -16,7 +17,6 @@ use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use thegraph_core::DeploymentId;
 
-use crate::database::{self, CostModel};
 use crate::service::SubgraphServiceState;
 
 lazy_static! {
@@ -128,7 +128,7 @@ impl Query {
         deployment_ids: Vec<DeploymentId>,
     ) -> Result<Vec<GraphQlCostModel>, anyhow::Error> {
         let pool = &ctx.data_unchecked::<Arc<SubgraphServiceState>>().database;
-        let cost_models = database::cost_models(pool, &deployment_ids).await?;
+        let cost_models = cost_model::cost_models(pool, &deployment_ids).await?;
 
         Ok(cost_models.into_iter().map(|m| m.into()).collect())
     }
@@ -138,7 +138,7 @@ impl Query {
         deployment_id: DeploymentId,
     ) -> Result<Option<GraphQlCostModel>, anyhow::Error> {
         let pool = &ctx.data_unchecked::<Arc<SubgraphServiceState>>().database;
-        database::cost_model(pool, &deployment_id)
+        cost_model::cost_model(pool, &deployment_id)
             .await
             .map(|model_opt| model_opt.map(GraphQlCostModel::from))
     }
diff --git a/service/src/service.rs b/service/src/service.rs
index e583bff75..12504677b 100644
--- a/service/src/service.rs
+++ b/service/src/service.rs
@@ -12,6 +12,7 @@ use indexer_common::indexer_service::http::{
 };
 use indexer_config::Config;
 use reqwest::Url;
+use serde::{de::DeserializeOwned, Serialize};
 use serde_json::{json, Value};
 use sqlx::PgPool;
 use thegraph_core::DeploymentId;
@@ -82,15 +83,14 @@ impl SubgraphService {
 #[async_trait]
 impl IndexerServiceImpl for SubgraphService {
     type Error = SubgraphServiceError;
-    type Request = serde_json::Value;
     type Response = SubgraphServiceResponse;
     type State = SubgraphServiceState;
 
-    async fn process_request(
+    async fn process_request<Request: DeserializeOwned + Send + Serialize>(
         &self,
         deployment: DeploymentId,
-        request: Self::Request,
-    ) -> Result<(Self::Request, Self::Response), Self::Error> {
+        request: Request,
+    ) -> Result<(Request, Self::Response), Self::Error> {
         let deployment_url = self
             .state
             .graph_node_query_base_url
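The tap-agent hunks that follow are the mechanical fallout of the tap_core bump: Check::check, Manager::create_rav_request, and finalize_receipt_checks now take an explicit Context, a type map that callers fill with per-request data (the MinimumValue tests above stash an AgoraQuery in it) and that checks can read back out. As a rough sketch of what a check looks like under the new signature (MinValueCheck is a hypothetical example, not code from this diff):

use tap_core::receipt::{
    checks::{Check, CheckError, CheckResult},
    state::Checking,
    Context, ReceiptWithState,
};

// Hypothetical check that rejects receipts below a fixed value threshold.
struct MinValueCheck {
    min_value: u128,
}

#[async_trait::async_trait]
impl Check for MinValueCheck {
    async fn check(
        &self,
        _ctx: &Context,
        receipt: &ReceiptWithState<Checking>,
    ) -> CheckResult {
        // A real check could pull request-scoped data out of _ctx
        // (e.g. the query string) before deciding.
        let value = receipt.signed_receipt().message.value;
        if value < self.min_value {
            return Err(CheckError::Failed(anyhow::anyhow!(
                "receipt value {value} is below the minimum {}",
                self.min_value
            )));
        }
        Ok(())
    }
}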
diff --git a/tap-agent/src/agent/sender_allocation.rs b/tap-agent/src/agent/sender_allocation.rs
index b44a896f3..572a25249 100644
--- a/tap-agent/src/agent/sender_allocation.rs
+++ b/tap-agent/src/agent/sender_allocation.rs
@@ -23,7 +23,7 @@ use tap_core::{
     receipt::{
         checks::{Check, CheckList},
         state::Failed,
-        ReceiptWithState,
+        Context, ReceiptWithState,
     },
     signed_message::EIP712SignedMessage,
 };
@@ -521,6 +521,7 @@ impl SenderAllocationState {
         } = self
             .tap_manager
             .create_rav_request(
+                &Context::new(),
                 self.timestamp_buffer_ns,
                 Some(self.rav_request_receipt_limit),
             )
@@ -884,7 +885,7 @@ pub mod tests {
     use tap_core::receipt::{
         checks::{Check, CheckError, CheckList, CheckResult},
         state::Checking,
-        ReceiptWithState,
+        Context, ReceiptWithState,
     };
    use tokio::sync::mpsc;
    use wiremock::{
@@ -1495,7 +1496,11 @@ pub mod tests {
     #[async_trait::async_trait]
     impl Check for FailingCheck {
-        async fn check(&self, _receipt: &ReceiptWithState<Checking>) -> CheckResult {
+        async fn check(
+            &self,
+            _: &tap_core::receipt::Context,
+            _receipt: &ReceiptWithState<Checking>,
+        ) -> CheckResult {
             Err(CheckError::Failed(anyhow::anyhow!("Failing check")))
         }
     }
@@ -1517,7 +1522,7 @@ pub mod tests {
             .into_iter()
             .map(|receipt| async {
                 receipt
-                    .finalize_receipt_checks(&checks)
+                    .finalize_receipt_checks(&Context::new(), &checks)
                     .await
                     .unwrap()
                     .unwrap_err()
diff --git a/tap-agent/src/tap/context/checks/allocation_id.rs b/tap-agent/src/tap/context/checks/allocation_id.rs
index 87f05fbe2..978bd7c45 100644
--- a/tap-agent/src/tap/context/checks/allocation_id.rs
+++ b/tap-agent/src/tap/context/checks/allocation_id.rs
@@ -46,7 +46,11 @@ impl AllocationId {
 #[async_trait::async_trait]
 impl Check for AllocationId {
-    async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
+    async fn check(
+        &self,
+        _: &tap_core::receipt::Context,
+        receipt: &ReceiptWithState<Checking>,
+    ) -> CheckResult {
         let allocation_id = receipt.signed_receipt().message.allocation_id;
 
         // TODO: Remove the if block below? Each TAP Monitor is specific to an allocation
         // ID. So the receipts that are received here should already have been filtered by
diff --git a/tap-agent/src/tap/context/checks/signature.rs b/tap-agent/src/tap/context/checks/signature.rs
index e4727ef6b..4b1bbe81a 100644
--- a/tap-agent/src/tap/context/checks/signature.rs
+++ b/tap-agent/src/tap/context/checks/signature.rs
@@ -29,7 +29,11 @@ impl Signature {
 
 #[async_trait::async_trait]
 impl Check for Signature {
-    async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
+    async fn check(
+        &self,
+        _: &tap_core::receipt::Context,
+        receipt: &ReceiptWithState<Checking>,
+    ) -> CheckResult {
         let signer = receipt
             .signed_receipt()
             .recover_signer(&self.domain_separator)
diff --git a/tap-agent/src/tap/context/checks/value.rs b/tap-agent/src/tap/context/checks/value.rs
index 5265a370c..414095d1a 100644
--- a/tap-agent/src/tap/context/checks/value.rs
+++ b/tap-agent/src/tap/context/checks/value.rs
@@ -24,7 +24,11 @@ pub struct Value {
 
 #[async_trait::async_trait]
 impl Check for Value {
-    async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
+    async fn check(
+        &self,
+        _: &tap_core::receipt::Context,
+        receipt: &ReceiptWithState<Checking>,
+    ) -> CheckResult {
         let value = receipt.signed_receipt().message.value;
         let query_id = receipt.signed_receipt().unique_hash();