diff --git a/Cargo.lock b/Cargo.lock index d39bb3daa56..a30d7d47660 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2243,7 +2243,9 @@ dependencies = [ "rand 0.9.2", "serde", "serde_json", + "sqlparser", "stable-hash 0.3.4", + "thiserror 2.0.16", ] [[package]] @@ -5157,6 +5159,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11a81a8cad9befe4cf1b9d2d4b9c6841c76f0882a3fec00d95133953c13b3d3d" dependencies = [ "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 2620237e28e..633c31a38e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,11 +81,8 @@ serde_derive = "1.0.125" serde_json = { version = "1.0", features = ["arbitrary_precision"] } serde_regex = "1.1.0" serde_yaml = "0.9.21" -slog = { version = "2.7.0", features = [ - "release_max_level_trace", - "max_level_trace", -] } -sqlparser = "0.46.0" +slog = { version = "2.7.0", features = ["release_max_level_trace", "max_level_trace"] } +sqlparser = { version = "0.46.0", features = ["visitor"] } strum = { version = "0.26", features = ["derive"] } syn = { version = "2.0.106", features = ["full"] } test-store = { path = "./store/test-store" } diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 747364dd0c4..a0a3cfd8cf5 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -284,3 +284,6 @@ those. graph-node bugs, but since it is hard to work around them, setting this variable to something like 10 makes it possible to work around such a bug while it is being fixed (default: 0) +- `GRAPH_ENABLE_SQL_QUERIES`: Enable the experimental [SQL query + interface](implementation/sql-interface.md). 
+ (default: false) diff --git a/docs/implementation/sql-interface.md b/docs/implementation/sql-interface.md new file mode 100644 index 00000000000..6b90fe6da9c --- /dev/null +++ b/docs/implementation/sql-interface.md @@ -0,0 +1,89 @@ +# SQL Queries + +**This interface is extremely experimental. There is no guarantee that this +interface will ever be brought to production use. It's solely here to help +evaluate the utility of such an interface** + +**The interface is only available if the environment variable `GRAPH_ENABLE_SQL_QUERIES` is set to `true`** + +SQL queries can be issued by posting a JSON document to +`/subgraphs/sql`. The server will respond with a JSON response that +contains the records matching the query in JSON form. + +The body of the request must contain the following keys: + +* `deployment`: the hash of the deployment against which the query should + be run +* `query`: the SQL query +* `mode`: either `info` or `data`. When the mode is `info` only some + information of the response is reported, with a mode of `data` the query + result is sent in the response + +The SQL query can use all the tables of the given subgraph. Table and +attribute names for normal `@entity` types are snake-cased from their form +in the GraphQL schema, so that data for `SomeDailyStuff` is stored in a +table `some_daily_stuff`. For `@aggregation` types, the table can be +accessed as `<aggregation>(<interval>)`, for example, `my_stats('hour')` for +`type MyStats @aggregation(..) { .. }` + +The query can use fairly arbitrary SQL, including aggregations and most +functions built into PostgreSQL. 
+ +## Example + +For a subgraph whose schema defines an entity `Block`, the following query +```json +{ + "query": "select number, hash, parent_hash, timestamp from block order by number desc limit 2", + "deployment": "QmSoMeThInG", + "mode": "data" +} +``` + +might result in this response +```json +{ + "data": [ + { + "hash": "\\x5f91e535ee4d328725b869dd96f4c42059e3f2728dfc452c32e5597b28ce68d6", + "number": 5000, + "parent_hash": "\\x82e95c1ee3a98cd0646225b5ae6afc0b0229367b992df97aeb669c898657a4bb", + "timestamp": "2015-07-30T20:07:44+00:00" + }, + { + "hash": "\\x82e95c1ee3a98cd0646225b5ae6afc0b0229367b992df97aeb669c898657a4bb", + "number": 4999, + "parent_hash": "\\x875c9a0f8215258c3b17fd5af5127541121cca1f594515aae4fbe5a7fbef8389", + "timestamp": "2015-07-30T20:07:36+00:00" + } + ] +} +``` + +## Limitations/Ideas/Disclaimers + +Most of these are fairly easy to address: + +* bind variables/query parameters are not supported, only literal SQL + queries +* queries must finish within `GRAPH_SQL_STATEMENT_TIMEOUT` (unlimited by + default) +* queries are always executed at the subgraph head. It would be easy to add + a way to specify a block at which the query should be executed +* the interface right now pretty much exposes the raw SQL schema for a + subgraph, though system columns like `vid` or `block_range` are made + inaccessible. +* it is not possible to join across subgraphs, though it would be possible + to add that. Implementing that would require some additional plumbing that + hides the effects of sharding. +* JSON as the response format is pretty terrible, and we should change that + to something that isn't so inefficient +* the response contains data that's pretty raw; as the example shows, + binary data uses Postgres' notation for hex strings +* because of how broad the supported SQL is, it is pretty easy to issue + queries that take a very long time. 
It will therefore not be hard to take + down a `graph-node`, especially when no query timeout is set + +Most importantly: while quite a bit of effort has been put into making this +interface safe, in particular, making sure it's not possible to write +through this interface, there's no guarantee that this works without bugs. diff --git a/graph/src/components/graphql.rs b/graph/src/components/graphql.rs index b5fc4273860..8d42cecb9d8 100644 --- a/graph/src/components/graphql.rs +++ b/graph/src/components/graphql.rs @@ -1,6 +1,7 @@ -use crate::data::query::QueryResults; use crate::data::query::{Query, QueryTarget}; -use crate::prelude::DeploymentHash; +use crate::data::query::{QueryResults, SqlQueryReq}; +use crate::data::store::SqlQueryObject; +use crate::prelude::{DeploymentHash, QueryExecutionError}; use async_trait::async_trait; use std::sync::Arc; @@ -28,6 +29,11 @@ pub trait GraphQlRunner: Send + Sync + 'static { ) -> QueryResults; fn metrics(&self) -> Arc; + + async fn run_sql_query( + self: Arc, + req: SqlQueryReq, + ) -> Result, QueryExecutionError>; } pub trait GraphQLMetrics: Send + Sync + 'static { diff --git a/graph/src/components/store/traits.rs b/graph/src/components/store/traits.rs index eae3a1b0b4c..0964fb9a026 100644 --- a/graph/src/components/store/traits.rs +++ b/graph/src/components/store/traits.rs @@ -16,7 +16,7 @@ use crate::components::transaction_receipt; use crate::components::versions::ApiVersion; use crate::data::query::Trace; use crate::data::store::ethereum::call; -use crate::data::store::QueryObject; +use crate::data::store::{QueryObject, SqlQueryObject}; use crate::data::subgraph::{status, DeploymentFeatures}; use crate::data::{query::QueryTarget, subgraph::schema::*}; use crate::prelude::{DeploymentState, NodeId, QueryExecutionError, SubgraphName}; @@ -652,6 +652,8 @@ pub trait QueryStore: Send + Sync { query: EntityQuery, ) -> Result<(Vec, Trace), QueryExecutionError>; + fn execute_sql(&self, sql: &str) -> Result, 
QueryExecutionError>; + async fn is_deployment_synced(&self) -> Result; async fn block_ptr(&self) -> Result, StoreError>; diff --git a/graph/src/data/query/error.rs b/graph/src/data/query/error.rs index cfd900ac596..1a85f34af8c 100644 --- a/graph/src/data/query/error.rs +++ b/graph/src/data/query/error.rs @@ -73,6 +73,7 @@ pub enum QueryExecutionError { InvalidSubgraphManifest, ResultTooBig(usize, usize), DeploymentNotFound(String), + SqlError(String), IdMissing, IdNotString, InternalError(String), @@ -135,6 +136,7 @@ impl QueryExecutionError { | IdMissing | IdNotString | InternalError(_) => false, + SqlError(_) => false, } } } @@ -213,7 +215,7 @@ impl fmt::Display for QueryExecutionError { } InvalidFilterError => write!(f, "Filter must by an object"), InvalidOrFilterStructure(fields, example) => { - write!(f, "Cannot mix column filters with 'or' operator at the same level. Found column filter(s) {} alongside 'or' operator.\n\n{}", + write!(f, "Cannot mix column filters with 'or' operator at the same level. 
Found column filter(s) {} alongside 'or' operator.\n\n{}", fields.join(", "), example) } EntityFieldError(e, a) => { @@ -281,6 +283,7 @@ impl fmt::Display for QueryExecutionError { IdMissing => write!(f, "entity is missing an `id` attribute"), IdNotString => write!(f, "entity `id` attribute is not a string"), InternalError(msg) => write!(f, "internal error: {}", msg), + SqlError(e) => write!(f, "sql error: {}", e), } } } diff --git a/graph/src/data/query/mod.rs b/graph/src/data/query/mod.rs index 73a6f1fe220..407c2218525 100644 --- a/graph/src/data/query/mod.rs +++ b/graph/src/data/query/mod.rs @@ -6,6 +6,6 @@ mod trace; pub use self::cache_status::CacheStatus; pub use self::error::{QueryError, QueryExecutionError}; -pub use self::query::{Query, QueryTarget, QueryVariables}; +pub use self::query::{Query, QueryTarget, QueryVariables, SqlQueryMode, SqlQueryReq}; pub use self::result::{LatestBlockInfo, QueryResult, QueryResults}; pub use self::trace::Trace; diff --git a/graph/src/data/query/query.rs b/graph/src/data/query/query.rs index 2ca93f0cc43..5bb64a8a134 100644 --- a/graph/src/data/query/query.rs +++ b/graph/src/data/query/query.rs @@ -1,7 +1,8 @@ use serde::de::Deserializer; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::convert::TryFrom; +use std::hash::{DefaultHasher, Hash as _, Hasher as _}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -165,3 +166,26 @@ impl Query { } } } + +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum SqlQueryMode { + Data, + Info, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct SqlQueryReq { + pub deployment: DeploymentHash, + pub query: String, + pub mode: SqlQueryMode, +} + +impl SqlQueryReq { + pub fn query_hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.deployment.hash(&mut hasher); + self.query.hash(&mut hasher); + hasher.finish() + } +} diff --git 
a/graph/src/data/store/mod.rs b/graph/src/data/store/mod.rs index c8786e9b473..f52c70b7ce3 100644 --- a/graph/src/data/store/mod.rs +++ b/graph/src/data/store/mod.rs @@ -1102,6 +1102,10 @@ pub struct QueryObject { pub entity: r::Object, } +/// An object that is returned from a SQL query. It wraps an `r::Value` +#[derive(CacheWeight, Serialize)] +pub struct SqlQueryObject(pub r::Value); + impl CacheWeight for QueryObject { fn indirect_weight(&self) -> usize { self.parent.indirect_weight() + self.entity.indirect_weight() diff --git a/graph/src/env/mod.rs b/graph/src/env/mod.rs index 802b304db1f..3fce087986e 100644 --- a/graph/src/env/mod.rs +++ b/graph/src/env/mod.rs @@ -24,6 +24,7 @@ lazy_static! { #[cfg(debug_assertions)] lazy_static! { pub static ref TEST_WITH_NO_REORG: Mutex = Mutex::new(false); + pub static ref TEST_SQL_QUERIES_ENABLED: Mutex = Mutex::new(false); } /// Panics if: @@ -189,6 +190,10 @@ pub struct EnvVars { /// Set by the environment variable `ETHEREUM_REORG_THRESHOLD`. The default /// value is 250 blocks. reorg_threshold: BlockNumber, + /// Enable SQL query interface. SQL queries are disabled by default + /// because they are still experimental. Set by the environment variable + /// `GRAPH_ENABLE_SQL_QUERIES`. Off by default. + enable_sql_queries: bool, /// The time to wait between polls when using polling block ingestor. /// The value is set by `ETHERUM_POLLING_INTERVAL` in millis and the /// default is 1000. 
@@ -341,6 +346,7 @@ impl EnvVars { external_ws_base_url: inner.external_ws_base_url, static_filters_threshold: inner.static_filters_threshold, reorg_threshold: inner.reorg_threshold, + enable_sql_queries: inner.enable_sql_queries.0, ingestor_polling_interval: Duration::from_millis(inner.ingestor_polling_interval), subgraph_settings: inner.subgraph_settings, prefer_substreams_block_streams: inner.prefer_substreams_block_streams, @@ -414,6 +420,27 @@ impl EnvVars { pub fn reorg_threshold(&self) -> i32 { self.reorg_threshold } + + #[cfg(debug_assertions)] + pub fn sql_queries_enabled(&self) -> bool { + // SQL queries are disabled by default for security. + // For testing purposes, we allow tests to enable SQL queries via TEST_SQL_QUERIES_ENABLED. + if *TEST_SQL_QUERIES_ENABLED.lock().unwrap() { + true + } else { + self.enable_sql_queries + } + } + #[cfg(not(debug_assertions))] + pub fn sql_queries_enabled(&self) -> bool { + self.enable_sql_queries + } + + #[cfg(debug_assertions)] + pub fn enable_sql_queries_for_tests(&self, enable: bool) { + let mut lock = TEST_SQL_QUERIES_ENABLED.lock().unwrap(); + *lock = enable; + } } impl Default for EnvVars { @@ -514,6 +541,8 @@ struct Inner { // JSON-RPC specific. 
#[envconfig(from = "ETHEREUM_REORG_THRESHOLD", default = "250")] reorg_threshold: BlockNumber, + #[envconfig(from = "GRAPH_ENABLE_SQL_QUERIES", default = "false")] + enable_sql_queries: EnvVarBoolean, #[envconfig(from = "ETHEREUM_POLLING_INTERVAL", default = "1000")] ingestor_polling_interval: u64, #[envconfig(from = "GRAPH_EXPERIMENTAL_SUBGRAPH_SETTINGS")] diff --git a/graph/src/schema/input/mod.rs b/graph/src/schema/input/mod.rs index 15cdaa30478..25a80a75415 100644 --- a/graph/src/schema/input/mod.rs +++ b/graph/src/schema/input/mod.rs @@ -1383,6 +1383,14 @@ impl InputSchema { .any(|ti| matches!(ti, TypeInfo::Aggregation(_))) } + pub fn aggregation_names(&self) -> impl Iterator { + self.inner + .type_infos + .iter() + .filter_map(TypeInfo::aggregation) + .map(|agg_type| self.inner.pool.get(agg_type.name).unwrap()) + } + pub fn entity_fulltext_definitions( &self, entity: &str, diff --git a/graphql/src/runner.rs b/graphql/src/runner.rs index d2f0bc9c96c..210f070acd6 100644 --- a/graphql/src/runner.rs +++ b/graphql/src/runner.rs @@ -4,12 +4,14 @@ use std::time::Instant; use crate::metrics::GraphQLMetrics; use crate::prelude::{QueryExecutionOptions, StoreResolver}; use crate::query::execute_query; +use graph::data::query::{CacheStatus, SqlQueryReq}; +use graph::data::store::SqlQueryObject; use graph::futures03::future; -use graph::prelude::MetricsRegistry; use graph::prelude::{ async_trait, o, CheapClone, DeploymentState, GraphQLMetrics as GraphQLMetricsTrait, GraphQlRunner as GraphQlRunnerTrait, Logger, Query, QueryExecutionError, ENV_VARS, }; +use graph::prelude::{ApiVersion, MetricsRegistry}; use graph::{data::graphql::load_manager::LoadManager, prelude::QueryStoreManager}; use graph::{ data::query::{LatestBlockInfo, QueryResults, QueryTarget}, @@ -251,4 +253,51 @@ where fn metrics(&self) -> Arc { self.graphql_metrics.clone() } + + async fn run_sql_query( + self: Arc, + req: SqlQueryReq, + ) -> Result, QueryExecutionError> { + // Check if SQL queries are enabled 
+ if !ENV_VARS.sql_queries_enabled() { + return Err(QueryExecutionError::SqlError( + "SQL queries are disabled. Set GRAPH_ENABLE_SQL_QUERIES=true to enable." + .to_string(), + )); + } + + let store = self + .store + .query_store(QueryTarget::Deployment( + req.deployment.clone(), + ApiVersion::default(), + )) + .await?; + + let query_hash = req.query_hash(); + self.load_manager + .decide( + &store.wait_stats(), + store.shard(), + store.deployment_id(), + query_hash, + &req.query, + ) + .to_result()?; + + let query_start = Instant::now(); + let result = store + .execute_sql(&req.query) + .map_err(|e| QueryExecutionError::from(e)); + + self.load_manager.record_work( + store.shard(), + store.deployment_id(), + query_hash, + query_start.elapsed(), + CacheStatus::Miss, + ); + + result + } } diff --git a/server/http/src/service.rs b/server/http/src/service.rs index 8e2237b86ff..c69e6428983 100644 --- a/server/http/src/service.rs +++ b/server/http/src/service.rs @@ -9,6 +9,8 @@ use graph::components::server::query::ServerResponse; use graph::components::server::query::ServerResult; use graph::components::versions::ApiVersion; use graph::data::query::QueryResult; +use graph::data::query::SqlQueryMode; +use graph::data::query::SqlQueryReq; use graph::data::subgraph::DeploymentHash; use graph::data::subgraph::SubgraphName; use graph::env::ENV_VARS; @@ -21,6 +23,8 @@ use graph::hyper::{body::Body, header::HeaderValue}; use graph::hyper::{Method, Request, Response, StatusCode}; use graph::prelude::serde_json; use graph::prelude::serde_json::json; +use graph::prelude::CacheWeight as _; +use graph::prelude::QueryError; use graph::semver::VersionReq; use graph::slog::error; use graph::slog::Logger; @@ -195,6 +199,51 @@ where Ok(result.as_http_response()) } + async fn handle_sql_query(&self, request: Request) -> ServerResult { + let body = request + .collect() + .await + .map_err(|_| ServerError::InternalError("Failed to read request body".into()))? 
+ .to_bytes(); + let sql_req: SqlQueryReq = serde_json::from_slice(&body) + .map_err(|e| ServerError::ClientError(format!("{}", e)))?; + + let mode = sql_req.mode; + let result = self + .graphql_runner + .cheap_clone() + .run_sql_query(sql_req) + .await + .map_err(|e| ServerError::QueryError(QueryError::from(e))); + + use SqlQueryMode::*; + let response_obj = match (result, mode) { + (Ok(result), Info) => { + json!({ + "count": result.len(), + "bytes" : result.weight(), + }) + } + (Ok(result), Data) => { + json!({ + "data": result, + }) + } + (Err(e), _) => json!({ + "error": e.to_string(), + }), + }; + + let response_str = serde_json::to_string(&response_obj).unwrap(); + + Ok(Response::builder() + .status(200) + .header(ACCESS_CONTROL_ALLOW_ORIGIN, "*") + .header(CONTENT_TYPE, "application/json") + .body(Full::from(response_str)) + .unwrap()) + } + // Handles OPTIONS requests fn handle_graphql_options(&self, _request: Request) -> ServerResult { Ok(Response::builder() @@ -327,7 +376,9 @@ where let dest = format!("/{}/graphql", filtered_path); self.handle_temp_redirect(dest) } - + (Method::POST, &["subgraphs", "sql"] | &["subgraphs", "sql", ""]) => { + self.handle_sql_query(req).await + } (Method::POST, &["subgraphs", "id", subgraph_id]) => { self.handle_graphql_query_by_id(subgraph_id.to_owned(), req) .await @@ -395,6 +446,7 @@ where #[cfg(test)] mod tests { + use graph::data::store::SqlQueryObject; use graph::data::value::{Object, Word}; use graph::http_body_util::{BodyExt, Full}; use graph::hyper::body::Bytes; @@ -402,7 +454,7 @@ mod tests { use graph::hyper::{Method, Request, StatusCode}; use graph::prelude::serde_json::json; - use graph::data::query::{QueryResults, QueryTarget}; + use graph::data::query::{QueryResults, QueryTarget, SqlQueryReq}; use graph::prelude::*; use crate::test_utils; @@ -449,6 +501,13 @@ mod tests { fn metrics(&self) -> Arc { Arc::new(TestGraphQLMetrics) } + + async fn run_sql_query( + self: Arc, + _req: SqlQueryReq, + ) -> Result, 
QueryExecutionError> { + unimplemented!() + } } #[tokio::test] diff --git a/server/http/tests/server.rs b/server/http/tests/server.rs index 08d5a41f363..9c8037f6f09 100644 --- a/server/http/tests/server.rs +++ b/server/http/tests/server.rs @@ -1,4 +1,7 @@ -use graph::http::StatusCode; +use graph::{ + data::{query::SqlQueryReq, store::SqlQueryObject}, + http::StatusCode, +}; use std::time::Duration; use graph::data::{ @@ -66,6 +69,13 @@ impl GraphQlRunner for TestGraphQlRunner { fn metrics(&self) -> Arc { Arc::new(TestGraphQLMetrics) } + + async fn run_sql_query( + self: Arc, + _req: SqlQueryReq, + ) -> Result, QueryExecutionError> { + unimplemented!(); + } } #[cfg(test)] diff --git a/store/postgres/Cargo.toml b/store/postgres/Cargo.toml index 392a5a1c3de..0d1abb23a58 100644 --- a/store/postgres/Cargo.toml +++ b/store/postgres/Cargo.toml @@ -32,6 +32,8 @@ git-testament = "0.2.6" itertools = "0.14.0" hex = "0.4.3" pretty_assertions = "1.4.1" +sqlparser = { workspace = true } +thiserror = { workspace = true } [dev-dependencies] clap.workspace = true diff --git a/store/postgres/src/deployment_store.rs b/store/postgres/src/deployment_store.rs index 7aaedf12895..f9aa0dfde75 100644 --- a/store/postgres/src/deployment_store.rs +++ b/store/postgres/src/deployment_store.rs @@ -12,8 +12,9 @@ use graph::components::store::{ PruningStrategy, QueryPermit, StoredDynamicDataSource, VersionStats, }; use graph::components::versions::VERSIONS; +use graph::data::graphql::IntoValue; use graph::data::query::Trace; -use graph::data::store::IdList; +use graph::data::store::{IdList, SqlQueryObject}; use graph::data::subgraph::{status, SPEC_VERSION_0_0_6}; use graph::data_source::CausalityRegion; use graph::derive::CheapClone; @@ -53,8 +54,8 @@ use crate::detail::ErrorDetail; use crate::dynds::DataSourcesTable; use crate::primary::{DeploymentId, Primary}; use crate::relational::index::{CreateIndex, IndexList, Method}; -use crate::relational::{self, Layout, LayoutCache, SqlName, Table}; -use 
crate::relational_queries::FromEntityData; +use crate::relational::{self, Layout, LayoutCache, SqlName, Table, STATEMENT_TIMEOUT}; +use crate::relational_queries::{FromEntityData, JSONData}; use crate::{advisory_lock, catalog, retry}; use crate::{detail, ConnectionPool}; use crate::{dynds, primary::Site}; @@ -290,6 +291,34 @@ impl DeploymentStore { layout.query(&logger, conn, query) } + pub(crate) fn execute_sql( + &self, + conn: &mut PgConnection, + query: &str, + ) -> Result, QueryExecutionError> { + let query = format!( + "select to_jsonb(sub.*) as data from ({}) as sub limit {}", + query, ENV_VARS.graphql.max_first + ); + let query = diesel::sql_query(query); + + let results = conn + .transaction(|conn| { + if let Some(ref timeout_sql) = *STATEMENT_TIMEOUT { + conn.batch_execute(timeout_sql)?; + } + + // Execute the provided SQL query + query.load::(conn) + }) + .map_err(|e| QueryExecutionError::SqlError(e.to_string()))?; + + Ok(results + .into_iter() + .map(|e| SqlQueryObject(e.into_value())) + .collect::>()) + } + fn check_intf_uniqueness( &self, conn: &mut PgConnection, diff --git a/store/postgres/src/lib.rs b/store/postgres/src/lib.rs index 0bbb261c154..794d8b966dd 100644 --- a/store/postgres/src/lib.rs +++ b/store/postgres/src/lib.rs @@ -30,6 +30,7 @@ pub mod query_store; mod relational; mod relational_queries; mod retry; +mod sql; mod store; mod store_events; mod subgraph_store; diff --git a/store/postgres/src/query_store.rs b/store/postgres/src/query_store.rs index ab6c43e55fd..56bfde13bb2 100644 --- a/store/postgres/src/query_store.rs +++ b/store/postgres/src/query_store.rs @@ -2,9 +2,10 @@ use std::collections::HashMap; use std::time::Instant; use crate::deployment_store::{DeploymentStore, ReplicaId}; +use crate::sql::Parser; use graph::components::store::{DeploymentId, QueryPermit, QueryStore as QueryStoreTrait}; use graph::data::query::Trace; -use graph::data::store::QueryObject; +use graph::data::store::{QueryObject, SqlQueryObject}; use 
graph::prelude::*; use graph::schema::{ApiSchema, InputSchema}; @@ -16,6 +17,7 @@ pub(crate) struct QueryStore { store: Arc, chain_store: Arc, api_version: Arc, + sql_parser: Result, } impl QueryStore { @@ -26,12 +28,16 @@ impl QueryStore { replica_id: ReplicaId, api_version: Arc, ) -> Self { + let sql_parser = store + .find_layout(site.clone()) + .map(|layout| Parser::new(layout, BLOCK_NUMBER_MAX)); QueryStore { site, replica_id, store, chain_store, api_version, + sql_parser, } } } @@ -57,6 +63,33 @@ impl QueryStoreTrait for QueryStore { }) } + fn execute_sql( + &self, + sql: &str, + ) -> Result, graph::prelude::QueryExecutionError> { + // Check if SQL queries are enabled + if !ENV_VARS.sql_queries_enabled() { + return Err(QueryExecutionError::SqlError( + "SQL queries are disabled. Set GRAPH_ENABLE_SQL_QUERIES=true to enable." + .to_string(), + )); + } + + let mut conn = self + .store + .get_replica_conn(self.replica_id) + .map_err(|e| QueryExecutionError::SqlError(format!("SQL error: {}", e)))?; + + let parser = self + .sql_parser + .as_ref() + .map_err(|e| QueryExecutionError::SqlError(format!("SQL error: {}", e)))?; + + let sql = parser.parse_and_validate(sql)?; + + self.store.execute_sql(&mut conn, &sql) + } + /// Return true if the deployment with the given id is fully synced, /// and return false otherwise. 
Errors from the store are passed back up async fn is_deployment_synced(&self) -> Result { diff --git a/store/postgres/src/relational.rs b/store/postgres/src/relational.rs index 35e35a35746..44bb73e6243 100644 --- a/store/postgres/src/relational.rs +++ b/store/postgres/src/relational.rs @@ -39,7 +39,8 @@ use graph::data_source::CausalityRegion; use graph::internal_error; use graph::prelude::{q, EntityQuery, StopwatchMetrics, ENV_VARS}; use graph::schema::{ - EntityKey, EntityType, Field, FulltextConfig, FulltextDefinition, InputSchema, + AggregationInterval, EntityKey, EntityType, Field, FulltextConfig, FulltextDefinition, + InputSchema, }; use graph::slog::warn; use index::IndexList; @@ -94,7 +95,7 @@ pub const STRING_PREFIX_SIZE: usize = 256; pub const BYTE_ARRAY_PREFIX_SIZE: usize = 64; lazy_static! { - static ref STATEMENT_TIMEOUT: Option = ENV_VARS + pub(crate) static ref STATEMENT_TIMEOUT: Option = ENV_VARS .graphql .sql_statement_timeout .map(|duration| format!("set local statement_timeout={}", duration.as_millis())); @@ -442,12 +443,13 @@ impl Layout { Ok(()) } - /// Find the table with the provided `name`. The name must exactly match - /// the name of an existing table. No conversions of the name are done - pub fn table(&self, name: &SqlName) -> Option<&Table> { + /// Find the table with the provided `sql_name`. The name must exactly + /// match the name of an existing table. No conversions of the name are + /// done + pub fn table(&self, sql_name: &str) -> Option<&Table> { self.tables .values() - .find(|table| &table.name == name) + .find(|table| &table.name == sql_name) .map(|rc| rc.as_ref()) } @@ -1155,6 +1157,27 @@ impl Layout { Ok(rollups) } + /// Given an aggregation name that is already snake-cased like `stats` + /// (for an an aggregation `type Stats @aggregation(..)`) and an + /// interval, return the table that holds the aggregated data, like + /// `stats_hour`. 
+ pub fn aggregation_table( + &self, + aggregation: &str, + interval: AggregationInterval, + ) -> Option<&Table> { + let sql_name = format!("{}_{interval}", aggregation); + self.table(&sql_name) + } + + /// Return true if the layout has an aggregation with the given name + /// like `stats` (already snake_cased) + pub fn has_aggregation(&self, aggregation: &str) -> bool { + self.input_schema + .aggregation_names() + .any(|agg_name| SqlName::from(agg_name).as_str() == aggregation) + } + /// Roll up all timeseries for each entry in `block_times`. The overall /// effect is that all buckets that end after `last_rollup` and before /// the last entry in `block_times` are filled. This will fill all diff --git a/store/postgres/src/relational/ddl_tests.rs b/store/postgres/src/relational/ddl_tests.rs index b15a40cecfb..6a9a2fdfaee 100644 --- a/store/postgres/src/relational/ddl_tests.rs +++ b/store/postgres/src/relational/ddl_tests.rs @@ -26,9 +26,7 @@ fn test_layout(gql: &str) -> Layout { #[test] fn table_is_sane() { let layout = test_layout(THING_GQL); - let table = layout - .table(&"thing".into()) - .expect("failed to get 'thing' table"); + let table = layout.table("thing").expect("failed to get 'thing' table"); assert_eq!(SqlName::from("thing"), table.name); assert_eq!("Thing", table.object.as_str()); diff --git a/store/postgres/src/relational_queries.rs b/store/postgres/src/relational_queries.rs index 89efd0415da..062a37526cc 100644 --- a/store/postgres/src/relational_queries.rs +++ b/store/postgres/src/relational_queries.rs @@ -14,6 +14,8 @@ use diesel::sql_types::{Array, BigInt, Binary, Bool, Int8, Integer, Jsonb, Text, use diesel::QuerySource as _; use graph::components::store::write::{EntityWrite, RowGroup, WriteChunk}; use graph::components::store::{Child as StoreChild, DerivedEntityQuery}; + +use graph::data::graphql::IntoValue; use graph::data::store::{Id, IdType, NULL}; use graph::data::store::{IdList, IdRef, QueryObject}; use graph::data::value::{Object, Word}; 
@@ -439,6 +441,50 @@ pub fn parse_id(id_type: IdType, json: serde_json::Value) -> Result r::Value {
+        JSONData::to_value(self.data)
+    }
+}
+
+impl JSONData {
+    /// Convert a JSON value into the corresponding `r::Value`, recursing
+    /// into arrays and objects.
+    ///
+    /// A number becomes `r::Value::Int` if it fits into an `i64`, otherwise
+    /// `r::Value::Float` if `as_f64` yields a value. When neither holds,
+    /// the number is passed on as `r::Value::String` holding its textual
+    /// form instead of panicking.
+    pub fn to_value(data: serde_json::Value) -> r::Value {
+        match data {
+            serde_json::Value::Null => r::Value::Null,
+            serde_json::Value::Bool(b) => r::Value::Boolean(b),
+            serde_json::Value::Number(n) => {
+                if let Some(i) = n.as_i64() {
+                    r::Value::Int(i)
+                } else if let Some(f) = n.as_f64() {
+                    r::Value::Float(f)
+                } else {
+                    // `as_f64` returns an `Option`; don't unwrap it as a
+                    // query result could contain a number it rejects
+                    r::Value::String(n.to_string())
+                }
+            }
+            serde_json::Value::String(s) => r::Value::String(s),
+            serde_json::Value::Array(vals) => {
+                r::Value::List(vals.into_iter().map(JSONData::to_value).collect())
+            }
+            serde_json::Value::Object(map) => {
+                let fields: std::collections::BTreeMap<_, _> = map
+                    .into_iter()
+                    .map(|(k, v)| (Word::from(k), JSONData::to_value(v)))
+                    .collect();
+                r::Value::object(fields)
+            }
+        }
+    }
+}
+
 /// Helper struct for retrieving entities from the database. With diesel, we
 /// can only run queries that return columns whose number and type are known
 /// at compile time. Because of that, we retrieve the actual data for an
diff --git a/store/postgres/src/sql/constants.rs b/store/postgres/src/sql/constants.rs
new file mode 100644
index 00000000000..b24f191f938
--- /dev/null
+++ b/store/postgres/src/sql/constants.rs
@@ -0,0 +1,435 @@
+use std::collections::HashSet;
+
+use lazy_static::lazy_static;
+use sqlparser::dialect::PostgreSqlDialect;
+
+lazy_static!
{ + pub(super) static ref ALLOWED_FUNCTIONS: HashSet<&'static str> = { + vec![ + // Comparison Functions see https://www.postgresql.org/docs/14/functions-comparison.html#FUNCTIONS-COMPARISON-FUNC-TABLE + "num_nonnulls", // Number of non-null arguments + "num_nulls", // Number of null arguments + + // Mathematical Functions see https://www.postgresql.org/docs/14/functions-math.html#FUNCTIONS-MATH-FUNC-TABLE + "abs", // Asolute value + "cbrt", // Cube root + "ceil", // Nearest integer greater than or equal to argument + "ceiling", // Nearest integer greater than or equal to argument + "degrees", // Converts radians to degrees + "div", // Integer quotient of y/x (truncates towards zero) + "exp", // Exponential (e raised to the given power) + "factorial", // Factorial + "floor", // Nearest integer less than or equal to argument + "gcd", // Greatest common divisor (the largest positive number that divides both inputs with no remainder); returns 0 if both inputs are zero; available for integer, bigint, and numeric + "lcm", // Least common multiple (the smallest strictly positive number that is an integral multiple of both inputs); returns 0 if either input is zero; available for integer, bigint, and numeric + "ln", // Natural logarithm + "log", // Base 10 logarithm + "log10", // Base 10 logarithm (same as log) + "mod", // Remainder of y/x; available for smallint, integer, bigint, and numeric + "pi", // Approximate value of π + "power", // a raised to the power of b + "radians", // Converts degrees to radians + "round", // Rounds to nearest integer. For numeric, ties are broken by rounding away from zero. For double precision, the tie-breaking behavior is platform dependent, but “round to nearest even” is the most common rule. 
+ "scale", // Scale of the argument (the number of decimal digits in the fractional part) + "sign", // Sign of the argument (-1, 0, or +1) + "sqrt", // Square root + "trim_scale", // Reduces the value's scale (number of fractional decimal digits) by removing trailing zeroes + "trunc", // Truncates to integer (towards zero) + "width_bucket", // Returns the number of the bucket in which operand falls in a histogram having count equal-width buckets spanning the range low to high. Returns 0 or count+1 for an input outside that range. + + // Random Functions see https://www.postgresql.org/docs/14/functions-math.html#FUNCTIONS-MATH-RANDOM-TABLE + "random", // Returns a random value in the range 0.0 <= x < 1.0 + "setseed", // Sets the seed for subsequent random() calls; argument must be between -1.0 and 1.0, inclusive + + // Trigonometric Functions see https://www.postgresql.org/docs/14/functions-math.html#FUNCTIONS-MATH-TRIG-TABLE + "acos", // Arc cosine, result in radians + "acosd", // Arc cosine, result in degrees + "asin", // Arc sine, result in radians + "asind", // Arc sine, result in degrees + "atan", // Arc tangent, result in radians + "atand", // Arc tangent, result in degrees + "atan2", // Arc tangent of y/x, result in radians + "atan2d", // Arc tangent of y/x, result in degrees + "cos", // Cosine, argument in radians + "cosd", // Cosine, argument in degrees + "cot", // Cotangent, argument in radians + "cotd", // Cotangent, argument in degrees + "sin", // Sine, argument in radians + "sind", // Sine, argument in degrees + "tan", // Tangent, argument in radians + "tand", // Tangent, argument in degrees + + // Hyperbolic Functions see https://www.postgresql.org/docs/14/functions-math.html#FUNCTIONS-MATH-HYPERBOLIC-TABLE + "sinh", // Hyperbolic sine + "cosh", // Hyperbolic cosine + "tanh", // Hyperbolic tangent + "asinh", // Inverse hyperbolic sine + "acosh", // Inverse hyperbolic cosine + "atanh", // Inverse hyperbolic tangent + + // String Functions see 
https://www.postgresql.org/docs/14/functions-string.html#FUNCTIONS-STRING-SQL + "bit_length", // Number of bits in string + "char_length", // Number of characters in string + "character_length", // Synonym for char_length + "lower", // Convert string to lower case + "normalize", // Convert string to specified Unicode normalization form + "octet_length", // Number of bytes in string + "overlay", // Replace substring + "position", // Location of specified substring + "substring", // Extract substring + "trim", // Remove leading and trailing characters + "upper", // Convert string to upper case + + //Additional string functions see https://www.postgresql.org/docs/14/functions-string.html#FUNCTIONS-STRING-OTHER + "ascii", // Convert first character to its numeric code + "btrim", // Remove the longest string containing only characters from characters (a space by default) from the start and end of string + "chr", // Convert integer to character + "concat", // Concatenate strings + "concat_ws", // Concatenate with separator + "format", // Format arguments according to a format string + "initcap", // Convert first letter of each word to upper case and the rest to lower case + "left", // Extract substring + "length", // Number of characters in string + "lpad", // Pad string to length length by prepending the characters fill (a space by default) + "ltrim", // Remove the longest string containing only characters from characters (a space by default) from the start of string + "md5", // Compute MD5 hash + "parse_ident", // Split qualified_identifier into an array of identifiers, removing any quoting of individual identifiers + "quote_ident", // Returns the given string suitably quoted to be used as an identifier in an SQL statement string + "quote_literal", // Returns the given string suitably quoted to be used as a string literal in an SQL statement string + "quote_nullable", // Returns the given string suitably quoted to be used as a string literal in an SQL statement string; 
or, if the argument is null, returns NULL + "regexp_match", // Returns captured substrings resulting from the first match of a POSIX regular expression to the string + "regexp_matches", // Returns captured substrings resulting from the first match of a POSIX regular expression to the string, or multiple matches if the g flag is used + "regexp_replace", // Replaces substrings resulting from the first match of a POSIX regular expression, or multiple substring matches if the g flag is used + "regexp_split_to_array", // Splits string using a POSIX regular expression as the delimiter, producing an array of results + "regexp_split_to_table", // Splits string using a POSIX regular expression as the delimiter, producing a set of results + "repeat", // Repeats string the specified number of times + "replace", // Replaces all occurrences in string of substring from with substring to + "reverse", // Reverses the order of the characters in the string + "right", // Extract substring + "rpad", // Pad string to length length by appending the characters fill (a space by default) + "rtrim", // Remove the longest string containing only characters from characters (a space by default) from the end of string + "split_part", // Splits string at occurrences of delimiter and returns the n'th field (counting from one), or when n is negative, returns the |n|'th-from-last field + "strpos", // Returns first starting index of the specified substring within string, or zero if it's not present + "substr", // Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified + "starts_with", // Returns true if string starts with prefix + "string_to_array", // Splits the string at occurrences of delimiter and forms the resulting fields into a text array + "string_to_table", // Splits the string at occurrences of delimiter and returns the resulting fields as a set of text rows + "to_ascii", // Converts string to ASCII from another encoding, 
which may be identified by name or number + "to_hex", // Converts the number to its equivalent hexadecimal representation + "translate", // Replaces each character in string that matches a character in the from set with the corresponding character in the to set + "unistr", // Evaluate escaped Unicode characters in the argument + + // Binary String Functions see https://www.postgresql.org/docs/14/functions-binarystring.html#FUNCTIONS-BINARYSTRING-OTHER + "bit_count", // Number of bits set in the argument + "get_bit", // Extracts the n'th bit from string + "get_byte", // Extracts the n'th byte from string + "set_bit", // Sets the n'th bit in string to newvalue + "set_byte", // Sets the n'th byte in string to newvalue + "sha224", // Compute SHA-224 hash + "sha256", // Compute SHA-256 hash + "sha384", // Compute SHA-384 hash + "sha512", // Compute SHA-512 hash + + // String conversion functions see https://www.postgresql.org/docs/14/functions-binarystring.html#FUNCTIONS-BINARYSTRING-CONVERSIONS + "convert", // Converts a binary string representing text in encoding src_encoding to a binary string in encoding dest_encoding + "convert_from", // Converts a binary string representing text in encoding src_encoding to text in the database encoding + "convert_to", // Converts a text string (in the database encoding) to a binary string encoded in encoding dest_encoding + "encode", // Encodes binary data into a textual representation + "decode", // Decodes binary data from a textual representation + + // Formatting Functions see https://www.postgresql.org/docs/14/functions-formatting.html#FUNCTIONS-FORMATTING-TABLE + "to_char", // Converts number to a string according to the given format + "to_date", // Converts string to date + "to_number", // Converts string to number + "to_timestamp", // Converts string to timestamp with time zone + + // Date/Time Functions see https://www.postgresql.org/docs/14/functions-datetime.html + "age", // Subtract arguments, producing a “symbolic” 
result that uses years and months, rather than just days + "clock_timestamp", // Current date and time (changes during statement execution) + "current_date", // Current date + "current_time", // Current time of day + "current_timestamp", // Current date and time (start of current transaction) + "date_bin", // Bin input into specified interval aligned with specified origin + "date_part", // Get subfield (equivalent to extract) + "date_trunc", // Truncate to specified precision + "extract", // Get subfield + "isfinite", // Test for finite date (not +/-infinity) + "justify_days", // Adjust interval so 30-day time periods are represented as months + "justify_hours", // Adjust interval so 24-hour time periods are represented as days + "justify_interval", // Adjust interval using justify_days and justify_hours, with additional sign adjustments + "localtime", // Current time of day + "localtimestamp", // Current date and time (start of current transaction) + "make_date", // Create date from year, month and day fields (negative years signify BC) + "make_interval", // Create interval from years, months, weeks, days, hours, minutes and seconds fields, each of which can default to zero + "make_time", // Create time from hour, minute and seconds fields + "make_timestamp", // Create timestamp from year, month, day, hour, minute and seconds fields (negative years signify BC) + "make_timestamptz", // Create timestamp with time zone from year, month, day, hour, minute and seconds fields (negative years signify BC). 
+ "now", // Current date and time (start of current transaction) + "statement_timestamp", // Current date and time (start of current statement) + "timeofday", // Current date and time (like clock_timestamp, but as a text string) + "transaction_timestamp", // Current date and time (start of current transaction) + + // Enum support functions see https://www.postgresql.org/docs/14/functions-enum.html#FUNCTIONS-ENUM-SUPPORT + "enum_first", // Returns the first value of an enum type + "enum_last", // Returns the last value of an enum type + "enum_range", // Returns a range of values of an enum type + + // Geometric Functions see https://www.postgresql.org/docs/14/functions-geometry.html + "area", // Computes area + "center", // Computes center point + "diagonal", // Extracts box's diagonal as a line segment (same as lseg(box)) + "diameter", // Computes diameter of circle + "height", // Computes vertical size of box + "isclosed", // Is path closed? + "isopen", // Is path open? + "length", // Computes the total length + "npoints", // Returns the number of points + "pclose", // Converts path to closed form + "popen", // Converts path to open form + "radius", // Computes radius of circle + "slope", // Computes slope of a line drawn through the two points + "width", // Computes horizontal size of box + + // Geometric Type Conversion Functions see https://www.postgresql.org/docs/14/functions-geometry.html + "box", // Convert to a box + "circle", // Convert to a circle + "line", // Convert to a line + "lseg", // Convert to a line segment + "path", // Convert to a path + "point", // Convert to a point + "polygon", // Convert to a polygon + + // IP Address Functions see https://www.postgresql.org/docs/14/functions-net.html + "abbrev", // Creates an abbreviated display format as text + "broadcast", // Computes the broadcast address for the address's network + "family", // Returns the address's family: 4 for IPv4, 6 for IPv6 + "host", // Returns the IP address as text, ignoring 
the netmask + "hostmask", // Computes the host mask for the address's network + "inet_merge", // Computes the smallest network that includes both of the given networks + "inet_same_family", // Tests whether the addresses belong to the same IP family + "masklen", // Returns the netmask length in bits + "netmask", // Computes the network mask for the address's network + "network", // Returns the network part of the address, zeroing out whatever is to the right of the netmask + "set_masklen", // Sets the netmask length for an inet value. The address part does not change + "text", // Returns the unabbreviated IP address and netmask length as text + + // MAC Address Functions see https://www.postgresql.org/docs/14/functions-net.html#MACADDR-FUNCTIONS-TABLE + "macaddr8_set7bit", //Sets the 7th bit of the address to one, creating what is known as modified EUI-64, for inclusion in an IPv6 address. + + // Text Search Functions see https://www.postgresql.org/docs/14/functions-textsearch.html + "array_to_tsvector", // Converts an array of lexemes to a tsvector + "get_current_ts_config", // Returns the OID of the current default text search configuration (as set by default_text_search_config) + "numnode", // Returns the number of lexemes plus operators in the tsquery + "plainto_tsquery", // Converts text to a tsquery, normalizing words according to the specified or default configuration. + "phraseto_tsquery", // Converts text to a tsquery, normalizing words according to the specified or default configuration. + "websearch_to_tsquery", // Converts text to a tsquery, normalizing words according to the specified or default configuration. + "querytree", // Produces a representation of the indexable portion of a tsquery. A result that is empty or just T indicates a non-indexable query. + "setweight", // Assigns the specified weight to each element of the vector. + "strip", // Removes positions and weights from the tsvector. 
+ "to_tsquery", // Converts text to a tsquery, normalizing words according to the specified or default configuration. + "to_tsvector", // Converts text to a tsvector, normalizing words according to the specified or default configuration. + "json_to_tsvector", // Selects each item in the JSON document that is requested by the filter and converts each one to a tsvector, normalizing words according to the specified or default configuration. + "jsonb_to_tsvector",// Selects each item in the JSON document that is requested by the filter and converts each one to a tsvector, normalizing words according to the specified or default configuration. + "ts_delete", // Removes any occurrence of the given lexeme from the vector. + "ts_filter", // Selects only elements with the given weights from the vector. + "ts_headline", // Displays, in an abbreviated form, the match(es) for the query in the document, which must be raw text not a tsvector. + "ts_rank", // Computes a score showing how well the vector matches the query. See Section 12.3.3 for details. + "ts_rank_cd", // Computes a score showing how well the vector matches the query, using a cover density algorithm. See Section 12.3.3 for details. + "ts_rewrite", // Replaces occurrences of target with substitute within the query. See Section + "tsquery_phrase", // Constructs a phrase query that searches for matches of query1 and query2 at successive lexemes (same as <-> operator). + "tsvector_to_array", // Converts a tsvector to an array of lexemes. + + // Text search debugging functions see https://www.postgresql.org/docs/14/functions-textsearch.html#TEXTSEARCH-FUNCTIONS-DEBUG-TABLE + "ts_debug", // Extracts and normalizes tokens from the document according to the specified or default text search configuration, and returns information about how each token was processed. See Section 12.8.1 for details. 
+ "ts_lexize", // Returns an array of replacement lexemes if the input token is known to the dictionary, or an empty array if the token is known to the dictionary but it is a stop word, or NULL if it is not a known word. See Section 12.8.3 for details. + "ts_parse", // Extracts tokens from the document using the named parser. See Section 12.8.2 for details. + "ts_token_type", // Returns a table that describes each type of token the named parser can recognize. See Section 12.8.2 for details. + + // UUID Functions see https://www.postgresql.org/docs/14/functions-uuid.html + "gen_random_uuid", // Generate a version 4 (random) UUID + + // XML Functions see https://www.postgresql.org/docs/14/functions-xml.html + "xmlcomment", // Creates an XML comment + "xmlconcat", // Concatenates XML values + "xmlelement", // Creates an XML element + "xmlforest", // Creates an XML forest (sequence) of elements + "xmlpi", // Creates an XML processing instruction + "xmlagg", // Concatenates the input values to the aggregate function call, much like xmlconcat does, except that concatenation occurs across rows rather than across expressions in a single row. + "xmlexists", // Evaluates an XPath 1.0 expression (the first argument), with the passed XML value as its context item. + "xml_is_well_formed", // Checks whether the argument is a well-formed XML document or fragment. + "xml_is_well_formed_content", // Checks whether the argument is a well-formed XML document or fragment, and that it contains no document type declaration. + "xml_is_well_formed_document", // Checks whether the argument is a well-formed XML document. + "xpath", // Evaluates the XPath 1.0 expression xpath (given as text) against the XML value xml. + "xpath_exists", // Evaluates the XPath 1.0 expression xpath (given as text) against the XML value xml, and returns true if the expression selects at least one node, otherwise false. 
+ "xmltable", // Expands an XML value into a table whose columns match the rowtype defined by the function's parameter list. + "table_to_xml", // Converts a table to XML. + "cursor_to_xml", // Converts a cursor to XML. + + // JSON and JSONB creation functions see https://www.postgresql.org/docs/14/functions-json.html#FUNCTIONS-JSON-CREATION-TABLE + "to_json", // Converts any SQL value to JSON. + "to_jsonb", // Converts any SQL value to JSONB. + "array_to_json", // Converts an SQL array to a JSON array. + "row_to_json", // Converts an SQL composite value to a JSON object. + "json_build_array", // Builds a possibly-heterogeneously-typed JSON array out of a variadic argument list. + "jsonb_build_array", // Builds a possibly-heterogeneously-typed JSON array out of a variadic argument list. + "json_build_object", // Builds a JSON object out of a variadic argument list. + "json_object", // Builds a JSON object out of a text array. + "jsonb_object", // Builds a JSONB object out of a text array. + + // JSON and JSONB processing functions see https://www.postgresql.org/docs/14/functions-json.html#FUNCTIONS-JSON-PROCESSING-TABLE + "json_array_elements", // Expands the top-level JSON array into a set of JSON values. + "jsonb_array_elements", // Expands the top-level JSON array into a set of JSONB values. + "json_array_elements_text", // Expands the top-level JSON array into a set of text values. + "jsonb_array_elements_text", // Expands the top-level JSONB array into a set of text values. + "json_array_length", // Returns the number of elements in the top-level JSON array. + "jsonb_array_length", // Returns the number of elements in the top-level JSONB array. + "json_each", // Expands the top-level JSON object into a set of key/value pairs. + "jsonb_each", // Expands the top-level JSONB object into a set of key/value pairs. + "json_each_text", // Expands the top-level JSON object into a set of key/value pairs. The returned values will be of type text. 
+ "jsonb_each_text", // Expands the top-level JSONB object into a set of key/value pairs. The returned values will be of type text. + "json_extract_path", // Extracts JSON sub-object at the specified path. + "jsonb_extract_path", // Extracts JSONB sub-object at the specified path. + "json_extract_path_text", // Extracts JSON sub-object at the specified path as text. + "jsonb_extract_path_text", // Extracts JSONB sub-object at the specified path as text. + "json_object_keys", // Returns the set of keys in the top-level JSON object. + "jsonb_object_keys", // Returns the set of keys in the top-level JSONB object. + "json_populate_record", // Expands the top-level JSON object to a row having the composite type of the base argument. + "jsonb_populate_record", // Expands the top-level JSON object to a row having the composite type of the base argument. + "json_populate_recordset", // Expands the top-level JSON array of objects to a set of rows having the composite type of the base argument. + "jsonb_populate_recordset", // Expands the top-level JSONB array of objects to a set of rows having the composite type of the base argument. + "json_to_record", // Expands the top-level JSON object to a row having the composite type defined by an AS clause. + "jsonb_to_record", // Expands the top-level JSONB object to a row having the composite type defined by an AS clause. + "json_to_recordset", // Expands the top-level JSON array of objects to a set of rows having the composite type defined by an AS clause. + "jsonb_to_recordset", // Expands the top-level JSONB array of objects to a set of rows having the composite type defined by an AS clause. + "json_strip_nulls", // Deletes all object fields that have null values from the given JSON value, recursively. + "jsonb_strip_nulls", // Deletes all object fields that have null values from the given JSONB value, recursively. 
+ "jsonb_set", // Returns target with the item designated by path replaced by new_value, or with new_value added if create_if_missing is true (which is the default) and the item designated by path does not exist. + "jsonb_set_lax", // If new_value is not NULL, behaves identically to jsonb_set. Otherwise behaves according to the value of null_value_treatment which must be one of 'raise_exception', 'use_json_null', 'delete_key', or 'return_target'. The default is 'use_json_null'. + "jsonb_insert", //Returns target with new_value inserted. + "jsonb_path_exists", // Checks whether the JSON path returns any item for the specified JSON value. + "jsonb_path_match", // Returns the result of a JSON path predicate check for the specified JSON value. + "jsonb_path_query", // Returns all JSON items returned by the JSON path for the specified JSON value. + "jsonb_path_query_array", // Returns all JSON items returned by the JSON path for the specified JSON value, as a JSON array. + "jsonb_path_query_first", // Returns the first JSON item returned by the JSON path for the specified JSON value. Returns NULL if there are no results. + "jsonb_path_exists_tz", // Support comparisons of date/time values that require timezone-aware conversions. + "jsonb_path_match_tz", // Support comparisons of date/time values that require timezone-aware conversions. + "jsonb_path_query_tz", // Support comparisons of date/time values that require timezone-aware conversions. + "jsonb_path_query_array_tz", // Support comparisons of date/time values that require timezone-aware conversions. + "jsonb_path_query_first_tz", // Support comparisons of date/time values that require timezone-aware conversions. + "jsonb_pretty", // Converts the given JSON value to pretty-printed, indented text. + "json_typeof", // Returns the type of the top-level JSON value as a text string. + "jsonb_typeof", // Returns the type of the top-level JSONB value as a text string. 
+ + // Conditional Expressions hhttps://www.postgresql.org/docs/14/functions-conditional.html + "coalesce", // Return first non-null argument. + "nullif", // Return null if two arguments are equal, otherwise return the first argument. + "greatest", // Return greatest of a list of values. + "least", // Return smallest of a list of values. + + // Array Functions https://www.postgresql.org/docs/14/functions-array.html#ARRAY-FUNCTIONS-TABLE + "array_append", // Appends an element to the end of an array (same as the || operator). + "array_cat", // Concatenates two arrays (same as the || operator). + "array_dims", // Returns a text representation of the array's dimensions. + "array_fill", // Returns an array filled with copies of the given value, having dimensions of the lengths specified by the second argument. The optional third argument supplies lower-bound values for each dimension (which default to all 1). + "array_length", // Returns the length of the requested array dimension. (Produces NULL instead of 0 for empty or missing array dimensions.) + "array_lower", // Returns the lower bound of the requested array dimension. + "array_ndims", // Returns the number of dimensions of the array. + "array_position", // Returns the subscript of the first occurrence of the second argument in the array, or NULL if it's not present. + "array_prepend", // Prepends an element to the beginning of an array (same as the || operator). + "array_remove", // Removes all elements equal to the given value from the array. The array must be one-dimensional. Comparisons are done using IS NOT DISTINCT FROM semantics, so it is possible to remove NULLs. + "array_replace", // Replaces each array element equal to the second argument with the third argument. + "array_to_string", // Converts each array element to its text representation, and concatenates those separated by the delimiter string. 
If null_string is given and is not NULL, then NULL array entries are represented by that string; otherwise, they are omitted. + "array_upper", // Returns the upper bound of the requested array dimension. + "cardinality", // Returns the total number of elements in the array, or 0 if the array is empty. + "trim_array", // Trims an array by removing the last n elements. If the array is multidimensional, only the first dimension is trimmed. + "unnest", // Expands an array into a set of rows. The array's elements are read out in storage order. + + // Range Functions https://www.postgresql.org/docs/14/functions-range.html#RANGE-FUNCTIONS-TABLE + "lower", // Extracts the lower bound of the range (NULL if the range is empty or the lower bound is infinite). + "upper", // Extracts the upper bound of the range (NULL if the range is empty or the upper bound is infinite). + "isempty", // Is the range empty? + "lower_inc", // Is the range's lower bound inclusive? + "upper_inc", // Is the range's upper bound inclusive? + "lower_inf", // Is the range's lower bound infinite? + "upper_inf", // Is the range's upper bound infinite? + "range_merge", // Computes the smallest range that includes both of the given ranges. + + // Multi-range Functions https://www.postgresql.org/docs/14/functions-range.html#MULTIRANGE-FUNCTIONS-TABLE + "multirange", // Returns a multirange containing just the given range. + + // General purpose aggregate functions https://www.postgresql.org/docs/14/functions-aggregate.html#FUNCTIONS-AGGREGATE-TABLE + "array_agg", // Collects all the input values, including nulls, into an array. + "avg", // Computes the average (arithmetic mean) of all the non-null input values. + "bit_and", // Computes the bitwise AND of all non-null input values. + "bit_or", // Computes the bitwise OR of all non-null input values. + "bit_xor", // Computes the bitwise exclusive OR of all non-null input values. Can be useful as a checksum for an unordered set of values. 
+ "bool_and", // Returns true if all non-null input values are true, otherwise false. + "bool_or", // Returns true if any non-null input value is true, otherwise false. + "count", // Computes the number of input rows. + "every", // This is the SQL standard's equivalent to bool_and. + "json_agg", // Collects all the input values, including nulls, into a JSON array. Values are converted to JSON as per to_json or to_jsonb. + "json_object_agg", // Collects all the key/value pairs into a JSON object. Key arguments are coerced to text; value arguments are converted as per to_json or to_jsonb. Values can be null, but not keys. + "max", // Computes the maximum of the non-null input values. Available for any numeric, string, date/time, or enum type, as well as inet, interval, money, oid, pg_lsn, tid, and arrays of any of these types. + "min", // Computes the minimum of the non-null input values. Available for any numeric, string, date/time, or enum type, as well as inet, interval, money, oid, pg_lsn, tid, and arrays of any of these types. + "range_agg", // Computes the union of the non-null input values. + "range_intersect_agg", // Computes the intersection of the non-null input values. + "string_agg", // Concatenates the non-null input values into a string. Each value after the first is preceded by the corresponding delimiter (if it's not null). + "sum", // Computes the sum of the non-null input values. + "xmlagg", // Concatenates the non-null XML input values. + + // Statistical aggregate functions https://www.postgresql.org/docs/14/functions-aggregate.html#FUNCTIONS-AGGREGATE-STATISTICS-TABLE + "corr", // Computes the correlation coefficient. + "covar_pop", // Computes the population covariance. + "covar_samp", // Computes the sample covariance. + "regr_avgx", // Computes the average of the independent variable, sum(X)/N. + "regr_avgy", // Computes the average of the dependent variable, sum(Y)/N. 
+ "regr_count", // Computes the number of rows in which both inputs are non-null. + "regr_intercept", // Computes the y-intercept of the least-squares-fit linear equation determined by the (X, Y) pairs. + "regr_r2", // Computes the square of the correlation coefficient. + "regr_slope", // Computes the slope of the least-squares-fit linear equation determined by the (X, Y) pairs. + "regr_sxx", // Computes the “sum of squares” of the independent variable, sum(X^2) - sum(X)^2/N. + "regr_sxy", // Computes the “sum of products” of independent times dependent variables, sum(X*Y) - sum(X) * sum(Y)/N. + "regr_syy", // Computes the “sum of squares” of the dependent variable, sum(Y^2) - sum(Y)^2/N. + "stddev", // This is a historical alias for stddev_samp. + "stddev_pop", // Computes the population standard deviation of the input values. + "stddev_samp", // Computes the sample standard deviation of the input values. + "variance", // This is a historical alias for var_samp. + "var_pop", // Computes the population variance of the input values (square of the population standard deviation). + "var_samp", // Computes the sample variance of the input values (square of the sample standard deviation). + + // Ordered-set aggregate functions https://www.postgresql.org/docs/14/functions-aggregate.html#FUNCTIONS-AGGREGATE-ORDEREDSET-TABLE + "mode", // Computes the mode (most frequent value) of the input values. + "percentile_cont", // Computes the continuous percentile of the input values. + "percentile_disc", // Computes the discrete percentile of the input values. + + // Hypothetical-set aggregate functions https://www.postgresql.org/docs/14/functions-aggregate.html#FUNCTIONS-AGGREGATE-HYPOTHETICAL-TABLE + "rank", // Computes the rank of the current row with gaps; same as row_number of its first peer. + "dense_rank", // Computes the rank of the current row without gaps; this function counts peer groups. 
+ "percent_rank", // Computes the relative rank (percentile) of the current row: (rank - 1) / (total partition rows - 1). + "cume_dist", // Computes the relative rank of the current row: (number of partition rows preceding or peer with current row) / (total partition rows). + + // Grouping set aggregate functions https://www.postgresql.org/docs/14/functions-aggregate.html#FUNCTIONS-AGGREGATE-GROUPINGSET-TABLE + "grouping", // Returns a bit mask indicating which GROUP BY expressions are not included in the current grouping set. + + // Window functions https://www.postgresql.org/docs/14/functions-window.html#FUNCTIONS-WINDOW-TABLE + "row_number", // Number of the current row within its partition, counting from 1. + "ntile", // Integer ranging from 1 to the argument value, dividing the partition as equally as possible. + "lag", // Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead returns default (which must be of a type compatible with value). + "lead", // Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead returns default (which must be of a type compatible with value). + "first_value", // Returns value evaluated at the row that is the first row of the window frame. + "last_value", // Returns value evaluated at the row that is the last row of the window frame. + "nth_value", // Returns value evaluated at the row that is the n'th row of the window frame (counting from 1); returns NULL if there is no such row. + + // Set returning functions https://www.postgresql.org/docs/14/functions-srf.html + "generate_series", // Expands range arguments into a set of rows. + "generate_subscripts", // Expands array arguments into a set of rows. 
+ + // Abbreviated syntax for common functions + "pow", // see power function + "date", // see to_date + + ].into_iter().collect() + }; +} + +pub(super) static SQL_DIALECT: PostgreSqlDialect = PostgreSqlDialect {}; diff --git a/store/postgres/src/sql/mod.rs b/store/postgres/src/sql/mod.rs new file mode 100644 index 00000000000..55917f854c4 --- /dev/null +++ b/store/postgres/src/sql/mod.rs @@ -0,0 +1,28 @@ +mod constants; +mod parser; +mod validation; + +pub use parser::Parser; + +#[cfg(test)] +mod test { + use std::{collections::BTreeSet, sync::Arc}; + + use graph::{prelude::DeploymentHash, schema::InputSchema}; + + use crate::{ + catalog::Catalog, + primary::{make_dummy_site, Namespace}, + relational::Layout, + }; + + pub(crate) fn make_layout(gql: &str) -> Layout { + let subgraph = DeploymentHash::new("Qmasubgraph").unwrap(); + let schema = InputSchema::parse_latest(gql, subgraph.clone()).unwrap(); + let namespace = Namespace::new("sgd0815".to_string()).unwrap(); + let site = Arc::new(make_dummy_site(subgraph, namespace, "anet".to_string())); + let catalog = Catalog::for_tests(site.clone(), BTreeSet::new()).unwrap(); + let layout = Layout::new(site, &schema, catalog).unwrap(); + layout + } +} diff --git a/store/postgres/src/sql/parser.rs b/store/postgres/src/sql/parser.rs new file mode 100644 index 00000000000..9f1b1483741 --- /dev/null +++ b/store/postgres/src/sql/parser.rs @@ -0,0 +1,174 @@ +use super::{constants::SQL_DIALECT, validation::Validator}; +use crate::relational::Layout; +use anyhow::{anyhow, Ok, Result}; +use graph::{env::ENV_VARS, prelude::BlockNumber}; +use std::sync::Arc; + +pub struct Parser { + layout: Arc, + block: BlockNumber, +} + +impl Parser { + pub fn new(layout: Arc, block: BlockNumber) -> Self { + Self { layout, block } + } + + pub fn parse_and_validate(&self, sql: &str) -> Result { + let mut statements = sqlparser::parser::Parser::parse_sql(&SQL_DIALECT, sql)?; + + let max_offset = ENV_VARS.graphql.max_skip; + let max_limit =
ENV_VARS.graphql.max_first; + + let mut validator = Validator::new(&self.layout, self.block, max_limit, max_offset); + validator.validate_statements(&mut statements)?; + + let statement = statements + .first() + .ok_or_else(|| anyhow!("No SQL statements found"))?; + + Ok(statement.to_string()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use crate::sql::{parser::SQL_DIALECT, test::make_layout}; + use graph::prelude::{lazy_static, serde_yaml, BLOCK_NUMBER_MAX}; + use serde::{Deserialize, Serialize}; + + use pretty_assertions::assert_eq; + + use super::Parser; + + const TEST_GQL: &str = r#" + type Swap @entity(immutable: true) { + id: Bytes! + timestamp: BigInt! + pool: Bytes! + token0: Bytes! + token1: Bytes! + sender: Bytes! + recipient: Bytes! + origin: Bytes! # the EOA that initiated the txn + amount0: BigDecimal! + amount1: BigDecimal! + amountUSD: BigDecimal! + sqrtPriceX96: BigInt! + tick: BigInt! + logIndex: BigInt + } + + type Token @entity { + id: ID! + address: Bytes! # address + symbol: String! + name: String! + decimals: Int! + } + + type Data @entity(timeseries: true) { + id: Int8! + timestamp: Timestamp! + price: Int! + } + + type Stats @aggregation(intervals: ["hour", "day"], source: "Data") { + id: Int8! + timestamp: Timestamp! + sum: BigDecimal!
@aggregate(fn: "sum", arg: "price") + } + "#; + + fn parse_and_validate(sql: &str) -> Result { + let parser = Parser::new(Arc::new(make_layout(TEST_GQL)), BLOCK_NUMBER_MAX); + + parser.parse_and_validate(sql) + } + + #[derive(Debug, Serialize, Deserialize)] + struct TestCase { + name: Option, + sql: String, + ok: Option, + err: Option, + } + + impl TestCase { + fn fail( + &self, + name: &str, + msg: &str, + exp: impl std::fmt::Display, + actual: impl std::fmt::Display, + ) { + panic!( + "case {name} failed: {}\n expected: {}\n actual: {}", + msg, exp, actual + ); + } + + fn run(&self, num: usize) { + fn normalize(query: &str) -> String { + sqlparser::parser::Parser::parse_sql(&SQL_DIALECT, query) + .unwrap() + .pop() + .unwrap() + .to_string() + } + + let name = self + .name + .as_ref() + .map(|name| format!("{num} ({name})")) + .unwrap_or_else(|| num.to_string()); + let result = parse_and_validate(&self.sql); + + match (&self.ok, &self.err, result) { + (Some(expected), None, Ok(actual)) => { + let actual = normalize(&actual); + let expected = normalize(expected); + assert_eq!(actual, expected, "case {} failed", name); + } + (None, Some(expected), Err(actual)) => { + let actual = actual.to_string(); + if !actual.contains(expected) { + self.fail(&name, "expected error message not found", expected, actual); + } + } + (Some(_), Some(_), _) => { + panic!("case {} has both ok and err", name); + } + (None, None, _) => { + panic!("case {} has neither ok nor err", name) + } + (None, Some(exp), Ok(actual)) => { + self.fail(&name, "expected an error", exp, actual); + } + (Some(exp), None, Err(actual)) => self.fail(&name, "expected success", exp, actual), + } + } + } + + lazy_static! 
{ + static ref TESTS: Vec = { + let file = std::path::PathBuf::from_iter([ + env!("CARGO_MANIFEST_DIR"), + "src", + "sql", + "parser_tests.yaml", + ]); + let tests = std::fs::read_to_string(file).unwrap(); + serde_yaml::from_str(&tests).unwrap() + }; + } + + #[test] + fn parse_sql() { + for (num, case) in TESTS.iter().enumerate() { + case.run(num); + } + } +} diff --git a/store/postgres/src/sql/parser_tests.yaml b/store/postgres/src/sql/parser_tests.yaml new file mode 100644 index 00000000000..f839ffdf761 --- /dev/null +++ b/store/postgres/src/sql/parser_tests.yaml @@ -0,0 +1,127 @@ +# Test cases for the SQL parser. Each test case has the following fields: +# name : an optional name for error messages +# sql : the SQL query to parse +# ok : the expected rewritten query +# err : a part of the error message if parsing should fail +# Of course, only one of ok and err can be specified + +- sql: select symbol, address from token where decimals > 10 + ok: > + select symbol, address from ( + select "id", "address", "symbol", "name", "decimals" from "sgd0815"."token" where block_range @> 2147483647) as token + where decimals > 10 +- sql: > + with tokens as ( + select * from (values + ('0x0000000000000000000000000000000000000000','eth','ethereum',18), + ('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48','usdc','usd coin',6) + ) as t(address,symbol,name,decimals)) + select date, t.symbol, sum(amount)/pow(10,t.decimals) as amount + from (select + date(to_timestamp(block_timestamp) at time zone 'utc') as date, + token, amount + from swap as sm, + unnest(sm.amounts_in,sm.tokens_in) as smi(amount,token) + union all + select + date(to_timestamp(block_timestamp) at time zone 'utc') as date, + token, amount + from swap as sm, + unnest(sm.amounts_out,sm.tokens_out) as smo(amount,token)) as tp + inner join + tokens as t on t.address = tp.token + group by tp.date, t.symbol, t.decimals + order by tp.date desc, amount desc + ok: > + with tokens as ( + select * from ( + values 
('0x0000000000000000000000000000000000000000', 'eth', 'ethereum', 18), + ('0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48', 'usdc', 'usd coin', 6)) + as t (address, symbol, name, decimals)) + select date, t.symbol, sum(amount) / pow(10, t.decimals) as amount + from (select date(to_timestamp(block_timestamp) at time zone 'utc') as date, token, amount + from (select "id", "timestamp", "pool", "token_0", "token_1", "sender", "recipient", "origin", "amount_0", "amount_1", "amount_usd", "sqrt_price_x96", "tick", "log_index" + from "sgd0815"."swap" where block$ <= 2147483647) as sm, + unnest(sm.amounts_in, sm.tokens_in) as smi (amount, token) + union all + select date(to_timestamp(block_timestamp) at time zone 'utc') as date, token, amount + from (select "id", "timestamp", "pool", "token_0", "token_1", "sender", "recipient", "origin", "amount_0", "amount_1", "amount_usd", "sqrt_price_x96", "tick", "log_index" + from "sgd0815"."swap" where block$ <= 2147483647) as sm, + unnest(sm.amounts_out, sm.tokens_out) as smo (amount, token)) as tp + join tokens as t on t.address = tp.token + group by tp.date, t.symbol, t.decimals + order by tp.date desc, amount desc +- name: pg_sleep forbidden + sql: select pool from swap where '' = (select cast(pg_sleep(5) as text)) + err: Unknown or unsupported function pg_sleep +- name: table functions forbidden + sql: > + select vid, k.sname + from swap, + lateral(select current_schemas as sname from current_schemas(true)) as k + err: Unknown or unsupported function current_schemas +- name: function without parens forbidden + sql: select input_token from swap where '' = (select user) + err: Unknown or unsupported function user +- name: aggregation allowed + sql: > + select token0, sum(amount0) as total_amount + from swap + group by token0 + having sum(amount0) > 1000 + ok: > + SELECT token0, sum(amount0) AS total_amount + FROM (SELECT "id", "timestamp", "pool", "token_0", "token_1", "sender", "recipient", "origin", "amount_0", "amount_1", 
"amount_usd", "sqrt_price_x96", "tick", "log_index" + FROM "sgd0815"."swap" WHERE block$ <= 2147483647) AS swap + GROUP BY token0 + HAVING sum(amount0) > 1000 +- name: arbitrary function forbidden + sql: > + select token0 from swap + where '' = (select cast(do_strange_math(amount_in) as text)) + err: Unknown or unsupported function do_strange_math +- name: create table forbidden + sql: create table foo (id int primary key); + err: Only SELECT query is supported +- name: insert forbidden + sql: insert into foo values (1); + err: Only SELECT query is supported +- name: CTE allowed + sql: with foo as (select 1) select * from foo + ok: with foo as (select 1) select * from foo +- name: CTE with insert forbidden + sql: with foo as (insert into target values(1)) select * from bar + err: Only SELECT query is supported +- name: only single statement + sql: select 1; select 2; + err: Multi statement is not supported +- name: unknown tables forbidden + sql: select * from unknown_table + err: Unknown table unknown_table +- name: qualified tables are forbidden + sql: select * from pg_catalog.pg_class + err: "Qualified table names are not supported: pg_catalog.pg_class" +- name: aggregation tables are hidden + sql: select * from stats_hour + err: Unknown table stats_hour +- name: CTEs take precedence + sql: with stats_hour as (select 1) select * from stats_hour + ok: WITH stats_hour AS (SELECT 1) SELECT * FROM stats_hour +- name: aggregation tables use function syntax + sql: select * from stats('hour') + ok: SELECT * FROM (SELECT "id", "timestamp", "sum" FROM "sgd0815"."stats_hour" WHERE block$ <= 2147483647) AS stats_hour +- name: unknown aggregation interval + sql: select * from stats('fortnight') + err: Unknown aggregation interval `fortnight` for table stats +- name: aggregation tables with empty arg + sql: select * from stats('') + err: Unknown aggregation interval `` for table stats +- name: aggregation tables with no args + sql: select * from stats() + err: Invalid syntax 
for aggregation stats +- name: aggregation tables with multiple args + sql: select * from stats('hour', 'day') + err: Invalid syntax for aggregation stats +- name: aggregation tables with alias + sql: select * from stats('hour') as sh + ok: SELECT * FROM (SELECT "id", "timestamp", "sum" FROM "sgd0815"."stats_hour" WHERE block$ <= 2147483647) AS sh diff --git a/store/postgres/src/sql/validation.rs b/store/postgres/src/sql/validation.rs new file mode 100644 index 00000000000..dd9db5cdb5c --- /dev/null +++ b/store/postgres/src/sql/validation.rs @@ -0,0 +1,283 @@ +use graph::prelude::BlockNumber; +use graph::schema::AggregationInterval; +use sqlparser::ast::{ + Expr, FunctionArg, FunctionArgExpr, Ident, ObjectName, Offset, Query, SetExpr, Statement, + TableAlias, TableFactor, Value, VisitMut, VisitorMut, +}; +use sqlparser::parser::Parser; +use std::result::Result; +use std::{collections::HashSet, ops::ControlFlow}; + +use crate::block_range::{BLOCK_COLUMN, BLOCK_RANGE_COLUMN}; +use crate::relational::Layout; + +use super::constants::{ALLOWED_FUNCTIONS, SQL_DIALECT}; + +#[derive(thiserror::Error, Debug, PartialEq)] +pub enum Error { + #[error("Unknown or unsupported function {0}")] + UnknownFunction(String), + #[error("Multi statement is not supported.")] + MultiStatementUnSupported, + #[error("Only SELECT query is supported.")] + NotSelectQuery, + #[error("Unknown table {0}")] + UnknownTable(String), + #[error("Unknown aggregation interval `{1}` for table {0}")] + UnknownAggregationInterval(String, String), + #[error("Invalid syntax for aggregation {0}")] + InvalidAggregationSyntax(String), + #[error("Only constant numbers are supported for LIMIT and OFFSET.")] + UnsupportedLimitOffset, + #[error("The limit of {0} is greater than the maximum allowed limit of {1}.")] + UnsupportedLimit(u32, u32), + #[error("The offset of {0} is greater than the maximum allowed offset of {1}.")] + UnsupportedOffset(u32, u32), + #[error("Qualified table names are not supported: {0}")] + 
NoQualifiedTables(String), +} + +pub struct Validator<'a> { + layout: &'a Layout, + ctes: HashSet, + block: BlockNumber, + max_limit: u32, + max_offset: u32, +} + +impl<'a> Validator<'a> { + pub fn new(layout: &'a Layout, block: BlockNumber, max_limit: u32, max_offset: u32) -> Self { + Self { + layout, + ctes: Default::default(), + block, + max_limit, + max_offset, + } + } + + fn validate_function_name(&self, name: &ObjectName) -> ControlFlow { + let name = name.to_string().to_lowercase(); + if ALLOWED_FUNCTIONS.contains(name.as_str()) { + ControlFlow::Continue(()) + } else { + ControlFlow::Break(Error::UnknownFunction(name)) + } + } + + pub fn validate_statements(&mut self, statements: &mut Vec) -> Result<(), Error> { + self.ctes.clear(); + + if statements.len() > 1 { + return Err(Error::MultiStatementUnSupported); + } + + if let ControlFlow::Break(error) = statements.visit(self) { + return Err(error); + } + + Ok(()) + } + + pub fn validate_limit_offset(&mut self, query: &mut Query) -> ControlFlow { + let Query { limit, offset, .. } = query; + + if let Some(limit) = limit { + match limit { + Expr::Value(Value::Number(s, _)) => match s.parse::() { + Err(_) => return ControlFlow::Break(Error::UnsupportedLimitOffset), + Ok(limit) => { + if limit > self.max_limit { + return ControlFlow::Break(Error::UnsupportedLimit( + limit, + self.max_limit, + )); + } + } + }, + _ => return ControlFlow::Break(Error::UnsupportedLimitOffset), + } + } + + if let Some(Offset { value, .. 
}) = offset { + match value { + Expr::Value(Value::Number(s, _)) => match s.parse::() { + Err(_) => return ControlFlow::Break(Error::UnsupportedLimitOffset), + Ok(offset) => { + if offset > self.max_offset { + return ControlFlow::Break(Error::UnsupportedOffset( + offset, + self.max_offset, + )); + } + } + }, + _ => return ControlFlow::Break(Error::UnsupportedLimitOffset), + } + } + ControlFlow::Continue(()) + } +} + +impl VisitorMut for Validator<'_> { + type Break = Error; + + fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow { + match _statement { + Statement::Query(_) => ControlFlow::Continue(()), + _ => ControlFlow::Break(Error::NotSelectQuery), + } + } + + fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow { + // Add common table expressions to the set of known tables + if let Some(ref with) = query.with { + self.ctes.extend( + with.cte_tables + .iter() + .map(|cte| cte.alias.name.value.to_lowercase()), + ); + } + + match *query.body { + SetExpr::Select(_) | SetExpr::Query(_) => { /* permitted */ } + SetExpr::SetOperation { .. } => { /* permitted */ } + SetExpr::Values(_) => { /* permitted */ } + // `TABLE name` is not a TableFactor, so it would skip the table checks below + SetExpr::Table(_) | SetExpr::Insert(_) | SetExpr::Update(_) => { + return ControlFlow::Break(Error::NotSelectQuery) + } + } + + self.validate_limit_offset(query) + } + + /// Invoked for any table function in the AST.
+ /// See [TableFactor::Table.args](sqlparser::ast::TableFactor::Table::args) for more details identifying a table function + fn post_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + /// Check whether `args` is a single string argument and return that + /// string + fn extract_string_arg(args: &Vec) -> Option { + if args.len() != 1 { + return None; + } + match &args[0] { + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString(s), + ))) => Some(s.clone()), + _ => None, + } + } + + if let TableFactor::Table { + name, args, alias, .. + } = table_factor + { + if name.0.len() != 1 { + // We do not support schema qualified table names + return ControlFlow::Break(Error::NoQualifiedTables(name.to_string())); + } + let table_name = &name.0[0].value; + + // CTES override subgraph tables + if self.ctes.contains(&table_name.to_lowercase()) && args.is_none() { + return ControlFlow::Continue(()); + } + + let table = match (self.layout.table(table_name), args) { + (None, None) => { + return ControlFlow::Break(Error::UnknownTable(table_name.clone())); + } + (Some(_), Some(_)) => { + // Table exists but has args, must be a function + return self.validate_function_name(&name); + } + (None, Some(args)) => { + // Table does not exist but has args, is either an + // aggregation table in the form () or + // must be a function + + if !self.layout.has_aggregation(table_name) { + // Not an aggregation, must be a function + return self.validate_function_name(&name); + } + + let Some(intv) = extract_string_arg(args) else { + // Looks like an aggregation, but argument is not a single string + return ControlFlow::Break(Error::InvalidAggregationSyntax( + table_name.clone(), + )); + }; + let Some(intv) = intv.parse::().ok() else { + return ControlFlow::Break(Error::UnknownAggregationInterval( + table_name.clone(), + intv, + )); + }; + + let Some(table) = self.layout.aggregation_table(table_name, intv) else { + return 
self.validate_function_name(&name); + }; + table + } + (Some(table), None) => { + if !table.object.is_object_type() { + // Interfaces and aggregations can not be queried + // with the table name directly + return ControlFlow::Break(Error::UnknownTable(table_name.clone())); + } + table + } + }; + + // Change 'from table [as alias]' to 'from (select {columns} from table) as alias' + let columns = table + .columns + .iter() + .map(|column| column.name.quoted()) + .collect::>() + .join(", "); + let query = if table.immutable { + format!( + "select {columns} from {} where {} <= {}", + table.qualified_name, BLOCK_COLUMN, self.block + ) + } else { + format!( + "select {columns} from {} where {} @> {}", + table.qualified_name, BLOCK_RANGE_COLUMN, self.block + ) + }; + let Statement::Query(subquery) = Parser::parse_sql(&SQL_DIALECT, &query) + .unwrap() + .pop() + .unwrap() + else { + unreachable!(); + }; + let alias = alias.clone().or_else(|| { + Some(TableAlias { + name: Ident::new(table.name.as_str()), + columns: vec![], + }) + }); + *table_factor = TableFactor::Derived { + lateral: false, + subquery, + alias, + }; + } + ControlFlow::Continue(()) + } + + /// Invoked for any function expressions that appear in the AST + fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow { + if let Expr::Function(function) = _expr { + return self.validate_function_name(&function.name); + } + ControlFlow::Continue(()) + } +} diff --git a/store/test-store/tests/graphql.rs b/store/test-store/tests/graphql.rs index 3ae1fcd2b74..86ed181da39 100644 --- a/store/test-store/tests/graphql.rs +++ b/store/test-store/tests/graphql.rs @@ -1,4 +1,5 @@ pub mod graphql { pub mod introspection; pub mod query; + pub mod sql; } diff --git a/store/test-store/tests/graphql/query.rs b/store/test-store/tests/graphql/query.rs index 5acf9754772..9dc01ce51ff 100644 --- a/store/test-store/tests/graphql/query.rs +++ b/store/test-store/tests/graphql/query.rs @@ -97,7 +97,7 @@ impl
std::fmt::Display for IdVal { } #[derive(Clone, Copy, Debug)] -enum IdType { +pub enum IdType { String, Bytes, Int8, @@ -157,7 +157,7 @@ impl IdType { } } - fn deployment_id(&self) -> &str { + pub fn deployment_id(&self) -> &str { match self { IdType::String => "graphqlTestsQuery", IdType::Bytes => "graphqlTestsQueryBytes", @@ -176,7 +176,7 @@ async fn setup_readonly(store: &Store) -> DeploymentLocator { /// data. If the `id` is the same as `id_type.deployment_id()`, the test /// must not modify the deployment in any way as these are reused for other /// tests that expect pristine data -async fn setup( +pub async fn setup( store: &Store, id: &str, features: BTreeSet, diff --git a/store/test-store/tests/graphql/sql.rs b/store/test-store/tests/graphql/sql.rs new file mode 100644 index 00000000000..ac0f3f8ea34 --- /dev/null +++ b/store/test-store/tests/graphql/sql.rs @@ -0,0 +1,289 @@ +// SQL Query Tests for Graph Node +// These tests parallel the GraphQL tests in query.rs but use SQL queries + +use graph::components::store::QueryStoreManager; +use graph::data::query::QueryTarget; +use graph::data::store::SqlQueryObject; +use graph::prelude::{r, QueryExecutionError}; +use std::collections::BTreeSet; +use test_store::{run_test_sequentially, STORE}; + +#[cfg(debug_assertions)] +use graph::env::ENV_VARS; + +// Import test setup from query.rs module +use super::query::{setup, IdType}; + +/// Synchronous wrapper for SQL query execution +fn run_sql_query(sql: &str, test: F) +where + F: Fn(Result, QueryExecutionError>, IdType) + Send + 'static, +{ + let sql = sql.to_string(); // Convert to owned String + run_test_sequentially(move |store| async move { + ENV_VARS.enable_sql_queries_for_tests(true); + + for id_type in [IdType::String, IdType::Bytes, IdType::Int8] { + let name = id_type.deployment_id(); + let deployment = setup(store.as_ref(), name, BTreeSet::new(), id_type).await; + + let query_store = STORE + .query_store(QueryTarget::Deployment( + deployment.hash.clone(), + 
Default::default(), + )) + .await + .unwrap(); + + let result = query_store.execute_sql(&sql); + test(result, id_type); + } + + ENV_VARS.enable_sql_queries_for_tests(false); + }); +} + +#[test] +fn sql_can_query_simple_select() { + const SQL: &str = "SELECT id, name FROM musician ORDER BY id"; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query should succeed"); + assert_eq!(results.len(), 5, "Should return 5 musicians"); + + // Check first musician + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + if let Some(r::Value::String(name)) = obj.get("name") { + assert_eq!(name, "John", "First musician should be John"); + } + } + } + }); +} + +#[test] +fn sql_can_query_with_where_clause() { + const SQL: &str = "SELECT id, name FROM musician WHERE name = 'John'"; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query should succeed"); + assert_eq!(results.len(), 1, "Should return 1 musician named John"); + + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + if let Some(r::Value::String(name)) = obj.get("name") { + assert_eq!(name, "John", "Should return John"); + } + } + } + }); +} + +#[test] +fn sql_can_query_with_aggregation() { + const SQL: &str = "SELECT COUNT(*) as total FROM musician"; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query should succeed"); + assert_eq!(results.len(), 1, "Should return 1 row with count"); + + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + if let Some(total) = obj.get("total") { + // The count should be a number (could be various forms) + match total { + r::Value::Int(n) => assert_eq!(*n, 5), + r::Value::String(s) => assert_eq!(s, "5"), + _ => panic!("Total should be a number: {:?}", total), + } + } + } + } + }); +} + +#[test] +fn sql_can_query_with_limit_offset() { + const SQL: &str = "SELECT id, name FROM musician ORDER BY id LIMIT 2 OFFSET 1"; 
+ + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query should succeed"); + assert_eq!(results.len(), 2, "Should return 2 musicians with offset"); + + // Should skip first musician (order may vary by id type) + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + if let Some(r::Value::String(name)) = obj.get("name") { + // Just check we got a valid musician name + assert!(["John", "Lisa", "Tom", "Valerie", "Paul"].contains(&name.as_str())); + } + } + } + }); +} + +#[test] +fn sql_can_query_with_group_by() { + const SQL: &str = " + SELECT COUNT(*) as musician_count + FROM musician + GROUP BY name + ORDER BY musician_count DESC + "; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query should succeed"); + assert!(!results.is_empty(), "Should return grouped musician counts"); + }); +} + +// Validation Tests + +#[test] +fn sql_validates_table_names() { + const SQL: &str = "SELECT * FROM invalid_table"; + + run_sql_query(SQL, |result, _| { + assert!(result.is_err(), "Query with invalid table should fail"); + if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Unknown table") || error_msg.contains("invalid_table"), + "Error should mention unknown table: {}", + error_msg + ); + } + }); +} + +#[test] +fn sql_validates_functions() { + // Try to use a potentially dangerous function + const SQL: &str = "SELECT pg_sleep(1)"; + + run_sql_query(SQL, |result, _| { + assert!(result.is_err(), "Query with blocked function should fail"); + if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Unknown or unsupported function") + || error_msg.contains("pg_sleep"), + "Error should mention unsupported function: {}", + error_msg + ); + } + }); +} + +#[test] +fn sql_blocks_ddl_statements() { + const SQL: &str = "DROP TABLE musician"; + + run_sql_query(SQL, |result, _| { + assert!(result.is_err(), "DDL statements should be blocked"); + 
if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Only SELECT query is supported") || error_msg.contains("DROP"), + "Error should mention unsupported statement type: {}", + error_msg + ); + } + }); +} + +#[test] +fn sql_blocks_dml_statements() { + const SQL: &str = "DELETE FROM musician WHERE id = 'm1'"; + + run_sql_query(SQL, |result, _| { + assert!(result.is_err(), "DML statements should be blocked"); + if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Only SELECT query is supported") + || error_msg.contains("DELETE"), + "Error should mention unsupported statement type: {}", + error_msg + ); + } + }); +} + +#[test] +fn sql_blocks_multi_statement() { + const SQL: &str = "SELECT * FROM musician; SELECT * FROM band"; + + run_sql_query(SQL, |result, _| { + assert!(result.is_err(), "Multi-statement queries should be blocked"); + if let Err(e) = result { + let error_msg = e.to_string(); + assert!( + error_msg.contains("Multi statement is not supported") + || error_msg.contains("multiple statements"), + "Error should mention multi-statement restriction: {}", + error_msg + ); + } + }); +} + +#[test] +fn sql_can_query_with_case_expression() { + const SQL: &str = " + SELECT + id, + name, + CASE + WHEN favorite_count > 10 THEN 'popular' + WHEN favorite_count > 5 THEN 'liked' + ELSE 'normal' + END as popularity + FROM musician + ORDER BY id + LIMIT 5 + "; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query with CASE should succeed"); + assert!( + results.len() <= 5, + "Should return limited musicians with popularity" + ); + + // Check that popularity field exists in first result + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + assert!( + obj.get("popularity").is_some(), + "Should have popularity field" + ); + } + } + }); +} + +#[test] +fn sql_can_query_with_subquery() { + const SQL: &str = " + WITH active_musicians AS ( + SELECT 
id, name + FROM musician + WHERE name IS NOT NULL + ) + SELECT COUNT(*) as active_count FROM active_musicians + "; + + run_sql_query(SQL, |result, _| { + let results = result.expect("SQL query with CTE should succeed"); + assert_eq!(results.len(), 1, "Should return one count result"); + + if let Some(first) = results.first() { + if let r::Value::Object(ref obj) = first.0 { + let count = obj.get("active_count"); + assert!(count.is_some(), "Should have active_count field"); + } + } + }); +}