diff --git a/docs/pages/product/apis-integrations/rest-api/reference.mdx b/docs/pages/product/apis-integrations/rest-api/reference.mdx index ea935aa175c31..3da22d4020322 100644 --- a/docs/pages/product/apis-integrations/rest-api/reference.mdx +++ b/docs/pages/product/apis-integrations/rest-api/reference.mdx @@ -119,12 +119,6 @@ If `disable_post_processing` is set to `true`, Cube will try to generate the SQL as if the query is run without [post-processing][ref-query-wpp], i.e., if it's run as a query with [pushdown][ref-query-wpd]. - - -Currently, the `disable_post_processing` parameter is not yet supported. - - - The response will contain a JSON object with the following properties under the `sql` key: | Property, type | Description | diff --git a/packages/cubejs-api-gateway/src/gateway.ts b/packages/cubejs-api-gateway/src/gateway.ts index 10e8889ccfd8d..59d3bd9822854 100644 --- a/packages/cubejs-api-gateway/src/gateway.ts +++ b/packages/cubejs-api-gateway/src/gateway.ts @@ -330,6 +330,7 @@ class ApiGateway { if (req.query.format === 'sql') { await this.sql4sql({ query: req.query.query, + disablePostProcessing: req.query.disable_post_processing === 'true', context: req.context, res: this.resToResultFn(res) }); @@ -349,6 +350,7 @@ class ApiGateway { if (req.body.format === 'sql') { await this.sql4sql({ query: req.body.query, + disablePostProcessing: req.body.disable_post_processing, context: req.context, res: this.resToResultFn(res) }); @@ -1308,13 +1310,14 @@ class ApiGateway { protected async sql4sql({ query, + disablePostProcessing, context, res, - }: {query: string} & BaseRequest) { + }: {query: string, disablePostProcessing: boolean} & BaseRequest) { try { await this.assertApiScope('data', context.securityContext); - const result = await this.sqlServer.sql4sql(query, context.securityContext); + const result = await this.sqlServer.sql4sql(query, disablePostProcessing, context.securityContext); res({ sql: result }); } catch (e: any) { this.handleError({ diff --git a/packages/cubejs-api-gateway/src/sql-server.ts b/packages/cubejs-api-gateway/src/sql-server.ts index b40c83f3edf11..e7ccc28e64023 100644 --- a/packages/cubejs-api-gateway/src/sql-server.ts +++ b/packages/cubejs-api-gateway/src/sql-server.ts @@ -64,8 +64,8 @@ export class SQLServer { await execSql(this.sqlInterfaceInstance!, sqlQuery, stream, securityContext); } - public async sql4sql(sqlQuery: string, securityContext?: any): Promise { - return sql4sql(this.sqlInterfaceInstance!, sqlQuery, securityContext); + public async sql4sql(sqlQuery: string, disablePostProcessing: boolean, securityContext?: unknown): Promise { + return sql4sql(this.sqlInterfaceInstance!, sqlQuery, disablePostProcessing, securityContext); } protected buildCheckSqlAuth(options: SQLServerOptions): CheckSQLAuthFn { diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index f32c824669916..b804d4d2b2fc8 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -405,10 +405,10 @@ export const execSql = async (instance: SqlInterfaceInstance, sqlQuery: string, }; // TODO parse result from native code -export const sql4sql = async (instance: SqlInterfaceInstance, sqlQuery: string, securityContext?: any): Promise => { +export const sql4sql = async (instance: SqlInterfaceInstance, sqlQuery: string, disablePostProcessing: boolean, securityContext?: unknown): Promise => { const native = loadNative(); - return native.sql4sql(instance, sqlQuery, securityContext ? JSON.stringify(securityContext) : null); + return native.sql4sql(instance, sqlQuery, disablePostProcessing, securityContext ? JSON.stringify(securityContext) : null); }; export const buildSqlAndParams = (cubeEvaluator: any): String => { diff --git a/packages/cubejs-backend-native/src/sql4sql.rs b/packages/cubejs-backend-native/src/sql4sql.rs index fdc627dede12b..bfeef78b49153 100644 --- a/packages/cubejs-backend-native/src/sql4sql.rs +++ b/packages/cubejs-backend-native/src/sql4sql.rs @@ -2,11 +2,13 @@ use std::sync::Arc; use neon::prelude::*; -use cubesql::compile::convert_sql_to_cube_query; use cubesql::compile::datafusion::logical_plan::LogicalPlan; +use cubesql::compile::datafusion::scalar::ScalarValue; +use cubesql::compile::datafusion::variable::VarType; use cubesql::compile::engine::df::scan::CubeScanNode; use cubesql::compile::engine::df::wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode}; -use cubesql::sql::Session; +use cubesql::compile::{convert_sql_to_cube_query, DatabaseVariable}; +use cubesql::sql::{Session, CUBESQL_PENALIZE_POST_PROCESSING_VAR}; use cubesql::transport::MetaContext; use cubesql::CubeError; @@ -157,8 +159,20 @@ async fn handle_sql4sql_query( services: Arc, native_auth_ctx: Arc, sql_query: &str, + disable_post_processing: bool, ) -> Result { with_session(&services, native_auth_ctx.clone(), |session| async move { + if disable_post_processing { + let v = DatabaseVariable { + name: CUBESQL_PENALIZE_POST_PROCESSING_VAR.to_string(), + value: ScalarValue::Boolean(Some(true)), + var_type: VarType::UserDefined, + readonly: false, + additional_params: None, + }; + session.state.set_variables(vec![v]); + } + let transport = session.server.transport.clone(); // todo: can we use compiler_cache? let meta_context = transport @@ -176,8 +190,9 @@ async fn handle_sql4sql_query( pub fn sql4sql(mut cx: FunctionContext) -> JsResult { let interface = cx.argument::>(0)?; let sql_query = cx.argument::(1)?.value(&mut cx); + let disable_post_processing = cx.argument::(2)?.value(&mut cx); - let security_context: Option = match cx.argument::(2) { + let security_context: Option = match cx.argument::(3) { Ok(string) => match string.downcast::(&mut cx) { Ok(v) => v.value(&mut cx).parse::().ok(), Err(_) => None, @@ -208,7 +223,13 @@ pub fn sql4sql(mut cx: FunctionContext) -> JsResult { // can do it relatively rare, and in a single loop for all JoinHandles // this is just a watchdog for a Very Bad case, so latency requirement can be quite relaxed runtime.spawn(async move { - let result = handle_sql4sql_query(services, native_auth_ctx, &sql_query).await; + let result = handle_sql4sql_query( + services, + native_auth_ctx, + &sql_query, + disable_post_processing, + ) + .await; if let Err(err) = deferred.try_settle_with(&channel, move |mut cx| { // `neon::result::ResultExt` is implemented only for Result, even though Ok variant is not touched diff --git a/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap b/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap index 6e8034d70056c..0c93be9f71419 100644 --- a/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap +++ b/packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap @@ -36,6 +36,64 @@ Object { } `; +exports[`SQL API Cube SQL over HTTP sql4sql double aggregation post-processing with disabled post-processing 1`] = ` +Object { + "body": Object { + "sql": Object { + "query_type": "pushdown", + "sql": Array [ + "SELECT \\"t\\".\\"avg_t_total_\\" \\"avg_t_total_\\" +FROM ( + SELECT AVG(\\"t\\".\\"total\\") \\"avg_t_total_\\" + FROM ( + SELECT + \\"orders\\".status \\"status\\", sum(\\"orders\\".amount) \\"total\\" + FROM + ( + select 1 as id, 100 as amount, 'new' status, '2024-01-01'::timestamptz created_at + UNION ALL + select 2 as id, 200 as amount, 'new' status, '2024-01-02'::timestamptz created_at + UNION ALL + select 3 as id, 300 as amount, 'processed' status, '2024-01-03'::timestamptz created_at + UNION ALL + select 4 as id, 500 as amount, 'processed' status, '2024-01-04'::timestamptz created_at + UNION ALL + select 5 as id, 600 as amount, 'shipped' status, '2024-01-05'::timestamptz created_at + ) AS \\"orders\\" GROUP BY 1 + ) AS \\"t\\" +) AS \\"t\\"", + Array [], + ], + "status": "ok", + }, + }, + "headers": Headers { + Symbol(map): Object { + "access-control-allow-origin": Array [ + "*", + ], + "connection": Array [ + "keep-alive", + ], + "content-length": Array [ + "878", + ], + "content-type": Array [ + "application/json; charset=utf-8", + ], + "keep-alive": Array [ + "timeout=5", + ], + "x-powered-by": Array [ + "Express", + ], + }, + }, + "status": 200, + "statusText": "OK", +} +`; + exports[`SQL API Cube SQL over HTTP sql4sql regular query 1`] = ` Object { "body": Object { @@ -244,6 +302,42 @@ Object { } `; +exports[`SQL API Cube SQL over HTTP sql4sql strictly post-processing with disabled post-processing 1`] = ` +Object { + "body": Object { + "sql": Object { + "error": "Provided query can not be executed without post-processing.", + "query_type": "post_processing", + "status": "error", + }, + }, + "headers": Headers { + Symbol(map): Object { + "access-control-allow-origin": Array [ + "*", + ], + "connection": Array [ + "keep-alive", + ], + "content-length": Array [ + "127", + ], + "content-type": Array [ + "application/json; charset=utf-8", + ], + "keep-alive": Array [ + "timeout=5", + ], + "x-powered-by": Array [ + "Express", + ], + }, + }, + "status": 200, + "statusText": "OK", +} +`; + exports[`SQL API Cube SQL over HTTP sql4sql wrapper 1`] = ` Object { "body": Object { diff --git a/packages/cubejs-testing/test/smoke-cubesql.test.ts b/packages/cubejs-testing/test/smoke-cubesql.test.ts index 53052bdf3bc1c..e8d71956921e4 100644 --- a/packages/cubejs-testing/test/smoke-cubesql.test.ts +++ b/packages/cubejs-testing/test/smoke-cubesql.test.ts @@ -149,7 +149,7 @@ describe('SQL API', () => { }); describe('sql4sql', () => { - async function generateSql(query: string) { + async function generateSql(query: string, disablePostPprocessing: boolean = false) { const response = await fetch(`${birdbox.configuration.apiUrl}/sql`, { method: 'POST', headers: { @@ -159,6 +159,7 @@ describe('SQL API', () => { body: JSON.stringify({ query, format: 'sql', + disable_post_processing: disablePostPprocessing, }), }); const { status, statusText, headers } = response; @@ -193,6 +194,10 @@ describe('SQL API', () => { expect(await generateSql(`SELECT version();`)).toMatchSnapshot(); }); + it('strictly post-processing with disabled post-processing', async () => { + expect(await generateSql(`SELECT version();`, true)).toMatchSnapshot(); + }); + it('double aggregation post-processing', async () => { expect(await generateSql(` SELECT AVG(total) @@ -206,6 +211,19 @@ describe('SQL API', () => { `)).toMatchSnapshot(); }); + it('double aggregation post-processing with disabled post-processing', async () => { + expect(await generateSql(` + SELECT AVG(total) + FROM ( + SELECT + status, + SUM(totalAmount) AS total + FROM Orders + GROUP BY 1 + ) t + `, true)).toMatchSnapshot(); + }); + it('wrapper', async () => { expect(await generateSql(` SELECT diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index 2a1732849901f..40e5594c854b5 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -14,11 +14,15 @@ use indexmap::IndexSet; #[derive(Debug)] pub struct BestCubePlan { meta_context: Arc, + penalize_post_processing: bool, } impl BestCubePlan { - pub fn new(meta_context: Arc) -> Self { - Self { meta_context } + pub fn new(meta_context: Arc, penalize_post_processing: bool) -> Self { + Self { + meta_context, + penalize_post_processing, + } } pub fn initial_cost(&self, enode: &LogicalPlanLanguage, top_down: bool) -> CubePlanCost { @@ -208,6 +212,8 @@ impl BestCubePlan { CubePlanCost { replacers: this_replacers, + // Will be filled in finalize + penalized_ast_size_outside_wrapper: 0, table_scans, filters, filter_members, @@ -239,8 +245,15 @@ impl BestCubePlan { } } +#[derive(Clone, Copy)] +pub struct CubePlanCostOptions { + top_down: bool, + penalize_post_processing: bool, +} + /// This cost struct maintains following structural relationships: /// - `replacers` > other nodes - having replacers in structure means not finished processing +/// - `penalized_ast_size_outside_wrapper` > other nodes - this is used to force "no post processing" mode, only CubeScan and CubeScanWrapped are expected as result /// - `table_scans` > other nodes - having table scan means not detected cube scan /// - `empty_wrappers` > `non_detected_cube_scans` - we don't want empty wrapper to hide non detected cube scan errors /// - `non_detected_cube_scans` > other nodes - minimize cube scans without members @@ -256,6 +269,7 @@ impl BestCubePlan { #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)] pub struct CubePlanCost { replacers: i64, + penalized_ast_size_outside_wrapper: usize, table_scans: i64, empty_wrappers: i64, non_detected_cube_scans: i64, @@ -353,11 +367,11 @@ impl CubePlanCostAndState { } } - pub fn finalize(&self, enode: &LogicalPlanLanguage) -> Self { + pub fn finalize(&self, enode: &LogicalPlanLanguage, options: CubePlanCostOptions) -> Self { Self { cost: self .cost - .finalize(&self.state, &self.sort_state, enode, false), + .finalize(&self.state, &self.sort_state, enode, options), state: self.state.clone(), sort_state: self.sort_state.clone(), } @@ -368,6 +382,8 @@ impl CubePlanCost { pub fn add_child(&self, other: &Self) -> Self { Self { replacers: self.replacers + other.replacers, + // Will be filled in finalize + penalized_ast_size_outside_wrapper: 0, table_scans: self.table_scans + other.table_scans, filters: self.filters + other.filters, non_detected_cube_scans: (if other.cube_members == 0 { @@ -419,10 +435,22 @@ impl CubePlanCost { state: &CubePlanState, sort_state: &SortState, enode: &LogicalPlanLanguage, - top_down: bool, + options: CubePlanCostOptions, ) -> Self { + let ast_size_outside_wrapper = match state { + CubePlanState::Wrapped => 0, + CubePlanState::Unwrapped(size) => *size, + CubePlanState::Wrapper => 0, + } + self.ast_size_outside_wrapper; + let penalized_ast_size_outside_wrapper = if options.penalize_post_processing { + ast_size_outside_wrapper + } else { + 0 + }; + Self { replacers: self.replacers, + penalized_ast_size_outside_wrapper, table_scans: self.table_scans, filters: self.filters, non_detected_cube_scans: match state { @@ -440,7 +468,7 @@ impl CubePlanCost { }, non_pushed_down_limit_sort: match sort_state { SortState::DirectChild => self.non_pushed_down_limit_sort, - SortState::Current if top_down => self.non_pushed_down_limit_sort, + SortState::Current if options.top_down => self.non_pushed_down_limit_sort, _ => 0, }, // Don't track state here: we want representation that have fewer wrappers with zero members _in total_ @@ -449,11 +477,7 @@ impl CubePlanCost { errors: self.errors, structure_points: self.structure_points, joins: self.joins, - ast_size_outside_wrapper: match state { - CubePlanState::Wrapped => 0, - CubePlanState::Unwrapped(size) => *size, - CubePlanState::Wrapper => 0, - } + self.ast_size_outside_wrapper, + ast_size_outside_wrapper, empty_wrappers: match state { CubePlanState::Wrapped => 0, CubePlanState::Unwrapped(_) => 0, @@ -538,7 +562,13 @@ impl CostFunction for BestCubePlan { let child = costs(*id); cost.add_child(&child) }) - .finalize(enode); + .finalize( + enode, + CubePlanCostOptions { + top_down: false, + penalize_post_processing: self.penalize_post_processing, + }, + ); res } } @@ -880,6 +910,15 @@ impl TopDownCostFunction CubePlanCost { - CubePlanCost::finalize(&cost, &state.wrapped, &state.limit, node, true) + CubePlanCost::finalize( + &cost, + &state.wrapped, + &state.limit, + node, + CubePlanCostOptions { + top_down: true, + penalize_post_processing: self.penalize_post_processing, + }, + ) } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs index 488ca241e053d..a562e9c91637a 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs @@ -15,6 +15,7 @@ use crate::{ CubeContext, }, config::ConfigObj, + sql::database_variables::postgres::session_vars::CUBESQL_PENALIZE_POST_PROCESSING_VAR, sql::{compiler_cache::CompilerCacheEntry, AuthContextRef}, transport::{MetaContext, SpanId}, CubeError, @@ -344,15 +345,26 @@ impl Rewriter { .rewrite_rules(cache_entry, true) .await?; + let penalize_post_processing = self + .cube_context + .session_state + .get_variable(CUBESQL_PENALIZE_POST_PROCESSING_VAR) + .map(|v| v.value); + let penalize_post_processing = match penalize_post_processing { + Some(ScalarValue::Boolean(val)) => val.unwrap_or(false), + _ => false, + }; + let (plan, qtrace_egraph_iterations, qtrace_best_graph) = tokio::task::spawn_blocking(move || { let (runner, qtrace_egraph_iterations) = Self::run_rewrites(&cube_context, egraph, rules, "final")?; + // TODO maybe check replacers and penalized_ast_size_outside_wrapper right after extraction? let best = if top_down_extractor { let mut extractor = TopDownExtractor::new( &runner.egraph, - BestCubePlan::new(cube_context.meta.clone()), + BestCubePlan::new(cube_context.meta.clone(), penalize_post_processing), CubePlanTopDownState::new(), ); let Some((best_cost, best)) = extractor.find_best(root) else { @@ -363,7 +375,7 @@ impl Rewriter { } else { let extractor = Extractor::new( &runner.egraph, - BestCubePlan::new(cube_context.meta.clone()), + BestCubePlan::new(cube_context.meta.clone(), penalize_post_processing), ); let (best_cost, best) = extractor.find_best(root); log::debug!("Best cost: {:#?}", best_cost); @@ -376,6 +388,7 @@ impl Rewriter { }; let new_root = Id::from(best.as_ref().len() - 1); log::debug!("Best: {}", best.pretty(120)); + // TODO maybe pass penalize_post_processing here as well, to break with sane error let converter = LanguageToLogicalPlanConverter::new( best, cube_context.clone(), diff --git a/rust/cubesql/cubesql/src/sql/database_variables/mysql/global_vars.rs b/rust/cubesql/cubesql/src/sql/database_variables/mysql/global_vars.rs index b7b010a8ab25a..67cc213cc2ffd 100644 --- a/rust/cubesql/cubesql/src/sql/database_variables/mysql/global_vars.rs +++ b/rust/cubesql/cubesql/src/sql/database_variables/mysql/global_vars.rs @@ -1,211 +1,138 @@ -use std::collections::HashMap; - use crate::compile::{DatabaseVariable, DatabaseVariables}; use datafusion::scalar::ScalarValue; pub fn defaults() -> DatabaseVariables { - let mut variables: DatabaseVariables = HashMap::new(); - - variables.insert( - "max_allowed_packet".to_string(), + let variables = [ DatabaseVariable::system( "max_allowed_packet".to_string(), ScalarValue::UInt32(Some(67108864)), None, ), - ); - variables.insert( - "auto_increment_increment".to_string(), DatabaseVariable::system( "auto_increment_increment".to_string(), ScalarValue::UInt32(Some(1)), None, ), - ); - variables.insert( - "version_comment".to_string(), DatabaseVariable::system( "version_comment".to_string(), ScalarValue::Utf8(Some("mysql".to_string())), None, ), - ); - variables.insert( - "system_time_zone".to_string(), DatabaseVariable::system( "system_time_zone".to_string(), ScalarValue::Utf8(Some("UTC".to_string())), None, ), - ); - variables.insert( - "time_zone".to_string(), DatabaseVariable::system( "time_zone".to_string(), ScalarValue::Utf8(Some("SYSTEM".to_string())), None, ), - ); - - variables.insert( - "tx_isolation".to_string(), DatabaseVariable::system( "tx_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "tx_read_only".to_string(), DatabaseVariable::system( "tx_read_only".to_string(), ScalarValue::Boolean(Some(false)), None, ), - ); - variables.insert( - "transaction_isolation".to_string(), DatabaseVariable::system( "transaction_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "transaction_read_only".to_string(), DatabaseVariable::system( "transaction_read_only".to_string(), ScalarValue::Boolean(Some(false)), None, ), - ); - variables.insert( - "sessiontransaction_isolation".to_string(), DatabaseVariable::system( "sessiontransaction_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "sessionauto_increment_increment".to_string(), DatabaseVariable::system( "sessionauto_increment_increment".to_string(), ScalarValue::Int64(Some(1)), None, ), - ); - variables.insert( - "character_set_client".to_string(), DatabaseVariable::system( "character_set_client".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_connection".to_string(), DatabaseVariable::system( "character_set_connection".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_results".to_string(), DatabaseVariable::system( "character_set_results".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_server".to_string(), DatabaseVariable::system( "character_set_server".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "collation_connection".to_string(), DatabaseVariable::system( "collation_connection".to_string(), ScalarValue::Utf8(Some("utf8mb4_general_ci".to_string())), None, ), - ); - variables.insert( - "collation_server".to_string(), DatabaseVariable::system( "collation_server".to_string(), ScalarValue::Utf8(Some("utf8mb4_0900_ai_ci".to_string())), None, ), - ); - variables.insert( - "init_connect".to_string(), DatabaseVariable::system( "init_connect".to_string(), ScalarValue::Utf8(Some("".to_string())), None, ), - ); - variables.insert( - "interactive_timeout".to_string(), DatabaseVariable::system( "interactive_timeout".to_string(), ScalarValue::Int32(Some(28800)), None, ), - ); - variables.insert( - "license".to_string(), DatabaseVariable::system( "license".to_string(), ScalarValue::Utf8(Some("Apache 2".to_string())), None, ), - ); - variables.insert( - "lower_case_table_names".to_string(), DatabaseVariable::system( "lower_case_table_names".to_string(), ScalarValue::Int32(Some(0)), None, ), - ); - variables.insert( - "net_buffer_length".to_string(), DatabaseVariable::system( "net_buffer_length".to_string(), ScalarValue::Int32(Some(16384)), None, ), - ); - variables.insert( - "net_write_timeout".to_string(), DatabaseVariable::system( "net_write_timeout".to_string(), ScalarValue::Int32(Some(600)), None, ), - ); - variables.insert( - "wait_timeout".to_string(), DatabaseVariable::system( "wait_timeout".to_string(), ScalarValue::Int32(Some(28800)), None, ), - ); - variables.insert( - "sql_mode".to_string(), - DatabaseVariable::system( - "sql_mode".to_string(), - ScalarValue::Utf8(Some("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION".to_string())), - None, - ), -); + DatabaseVariable::system( + "sql_mode".to_string(), + ScalarValue::Utf8(Some("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION".to_string())), + None, + ), + ]; + + let variables = IntoIterator::into_iter(variables) + .map(|v| (v.name.clone(), v)) + .collect::(); + variables } diff --git a/rust/cubesql/cubesql/src/sql/database_variables/mysql/session_vars.rs b/rust/cubesql/cubesql/src/sql/database_variables/mysql/session_vars.rs index b7b010a8ab25a..67cc213cc2ffd 100644 --- a/rust/cubesql/cubesql/src/sql/database_variables/mysql/session_vars.rs +++ b/rust/cubesql/cubesql/src/sql/database_variables/mysql/session_vars.rs @@ -1,211 +1,138 @@ -use std::collections::HashMap; - use crate::compile::{DatabaseVariable, DatabaseVariables}; use datafusion::scalar::ScalarValue; pub fn defaults() -> DatabaseVariables { - let mut variables: DatabaseVariables = HashMap::new(); - - variables.insert( - "max_allowed_packet".to_string(), + let variables = [ DatabaseVariable::system( "max_allowed_packet".to_string(), ScalarValue::UInt32(Some(67108864)), None, ), - ); - variables.insert( - "auto_increment_increment".to_string(), DatabaseVariable::system( "auto_increment_increment".to_string(), ScalarValue::UInt32(Some(1)), None, ), - ); - variables.insert( - "version_comment".to_string(), DatabaseVariable::system( "version_comment".to_string(), ScalarValue::Utf8(Some("mysql".to_string())), None, ), - ); - variables.insert( - "system_time_zone".to_string(), DatabaseVariable::system( "system_time_zone".to_string(), ScalarValue::Utf8(Some("UTC".to_string())), None, ), - ); - variables.insert( - "time_zone".to_string(), DatabaseVariable::system( "time_zone".to_string(), ScalarValue::Utf8(Some("SYSTEM".to_string())), None, ), - ); - - variables.insert( - "tx_isolation".to_string(), DatabaseVariable::system( "tx_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "tx_read_only".to_string(), DatabaseVariable::system( "tx_read_only".to_string(), ScalarValue::Boolean(Some(false)), None, ), - ); - variables.insert( - "transaction_isolation".to_string(), DatabaseVariable::system( "transaction_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "transaction_read_only".to_string(), DatabaseVariable::system( "transaction_read_only".to_string(), ScalarValue::Boolean(Some(false)), None, ), - ); - variables.insert( - "sessiontransaction_isolation".to_string(), DatabaseVariable::system( "sessiontransaction_isolation".to_string(), ScalarValue::Utf8(Some("REPEATABLE-READ".to_string())), None, ), - ); - variables.insert( - "sessionauto_increment_increment".to_string(), DatabaseVariable::system( "sessionauto_increment_increment".to_string(), ScalarValue::Int64(Some(1)), None, ), - ); - variables.insert( - "character_set_client".to_string(), DatabaseVariable::system( "character_set_client".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_connection".to_string(), DatabaseVariable::system( "character_set_connection".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_results".to_string(), DatabaseVariable::system( "character_set_results".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "character_set_server".to_string(), DatabaseVariable::system( "character_set_server".to_string(), ScalarValue::Utf8(Some("utf8mb4".to_string())), None, ), - ); - variables.insert( - "collation_connection".to_string(), DatabaseVariable::system( "collation_connection".to_string(), ScalarValue::Utf8(Some("utf8mb4_general_ci".to_string())), None, ), - ); - variables.insert( - "collation_server".to_string(), DatabaseVariable::system( "collation_server".to_string(), ScalarValue::Utf8(Some("utf8mb4_0900_ai_ci".to_string())), None, ), - ); - variables.insert( - "init_connect".to_string(), DatabaseVariable::system( "init_connect".to_string(), ScalarValue::Utf8(Some("".to_string())), None, ), - ); - variables.insert( - "interactive_timeout".to_string(), DatabaseVariable::system( "interactive_timeout".to_string(), ScalarValue::Int32(Some(28800)), None, ), - ); - variables.insert( - "license".to_string(), DatabaseVariable::system( "license".to_string(), ScalarValue::Utf8(Some("Apache 2".to_string())), None, ), - ); - variables.insert( - "lower_case_table_names".to_string(), DatabaseVariable::system( "lower_case_table_names".to_string(), ScalarValue::Int32(Some(0)), None, ), - ); - variables.insert( - "net_buffer_length".to_string(), DatabaseVariable::system( "net_buffer_length".to_string(), ScalarValue::Int32(Some(16384)), None, ), - ); - variables.insert( - "net_write_timeout".to_string(), DatabaseVariable::system( "net_write_timeout".to_string(), ScalarValue::Int32(Some(600)), None, ), - ); - variables.insert( - "wait_timeout".to_string(), DatabaseVariable::system( "wait_timeout".to_string(), ScalarValue::Int32(Some(28800)), None, ), - ); - variables.insert( - "sql_mode".to_string(), - DatabaseVariable::system( - "sql_mode".to_string(), - ScalarValue::Utf8(Some("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION".to_string())), - None, - ), -); + DatabaseVariable::system( + "sql_mode".to_string(), + ScalarValue::Utf8(Some("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION".to_string())), + None, + ), + ]; + + let variables = IntoIterator::into_iter(variables) + .map(|v| (v.name.clone(), v)) + .collect::(); + variables } diff --git a/rust/cubesql/cubesql/src/sql/database_variables/postgres/global_vars.rs b/rust/cubesql/cubesql/src/sql/database_variables/postgres/global_vars.rs index 5068393135045..814e317f9c466 100644 --- a/rust/cubesql/cubesql/src/sql/database_variables/postgres/global_vars.rs +++ b/rust/cubesql/cubesql/src/sql/database_variables/postgres/global_vars.rs @@ -1,64 +1,44 @@ -use std::collections::HashMap; - use datafusion::scalar::ScalarValue; use crate::compile::{DatabaseVariable, DatabaseVariables}; pub fn defaults() -> DatabaseVariables { - let mut variables: DatabaseVariables = HashMap::new(); - - variables.insert( - "application_name".to_string(), + let variables = [ DatabaseVariable::system( "application_name".to_string(), ScalarValue::Utf8(None), None, ), - ); - - variables.insert( - "extra_float_digits".to_string(), DatabaseVariable::system( "extra_float_digits".to_string(), ScalarValue::UInt32(Some(1)), None, ), - ); - - variables.insert( - "transaction_isolation".to_string(), DatabaseVariable::system( "transaction_isolation".to_string(), ScalarValue::Utf8(Some("read committed".to_string())), None, ), - ); - - variables.insert( - "max_allowed_packet".to_string(), DatabaseVariable::system( "max_allowed_packet".to_string(), ScalarValue::UInt32(Some(67108864)), None, ), - ); - variables.insert( - "max_index_keys".to_string(), DatabaseVariable::system( "max_index_keys".to_string(), ScalarValue::UInt32(Some(32)), None, ), - ); - - variables.insert( - "lc_collate".to_string(), DatabaseVariable::system( "lc_collate".to_string(), ScalarValue::Utf8(Some("en_US.utf8".to_string())), None, ), - ); + ]; + + let variables = IntoIterator::into_iter(variables) + .map(|v| (v.name.clone(), v)) + .collect::(); variables } diff --git a/rust/cubesql/cubesql/src/sql/database_variables/postgres/session_vars.rs b/rust/cubesql/cubesql/src/sql/database_variables/postgres/session_vars.rs index 4470d2bde4005..a99625081fe29 100644 --- a/rust/cubesql/cubesql/src/sql/database_variables/postgres/session_vars.rs +++ b/rust/cubesql/cubesql/src/sql/database_variables/postgres/session_vars.rs @@ -1,109 +1,77 @@ use datafusion::scalar::ScalarValue; -use std::collections::HashMap; use crate::compile::{DatabaseVariable, DatabaseVariables}; -pub fn defaults() -> DatabaseVariables { - let mut variables: DatabaseVariables = HashMap::new(); +pub const CUBESQL_PENALIZE_POST_PROCESSING_VAR: &str = "cubesql_penalize_post_processing"; - variables.insert( - "client_min_messages".to_string(), +pub fn defaults() -> DatabaseVariables { + let variables = [ DatabaseVariable::system( "client_min_messages".to_string(), ScalarValue::Utf8(Some("NOTICE".to_string())), None, ), - ); - - variables.insert( - "timezone".to_string(), DatabaseVariable::system( "timezone".to_string(), ScalarValue::Utf8(Some("GMT".to_string())), None, ), - ); - - variables.insert( - "application_name".to_string(), DatabaseVariable::system( "application_name".to_string(), ScalarValue::Utf8(None), None, ), - ); - - variables.insert( - "extra_float_digits".to_string(), DatabaseVariable::system( "extra_float_digits".to_string(), ScalarValue::UInt32(Some(1)), None, ), - ); - - variables.insert( - "transaction_isolation".to_string(), DatabaseVariable::system( "transaction_isolation".to_string(), ScalarValue::Utf8(Some("read committed".to_string())), None, ), - ); - - variables.insert( - "max_allowed_packet".to_string(), DatabaseVariable::system( "max_allowed_packet".to_string(), ScalarValue::UInt32(Some(67108864)), None, ), - ); - - variables.insert( - "max_index_keys".to_string(), DatabaseVariable::system( "max_index_keys".to_string(), ScalarValue::UInt32(Some(32)), None, ), - ); - - variables.insert( - "lc_collate".to_string(), DatabaseVariable::system( "lc_collate".to_string(), ScalarValue::Utf8(Some("en_US.utf8".to_string())), None, ), - ); - - variables.insert( - "standard_conforming_strings".to_string(), DatabaseVariable::system( "standard_conforming_strings".to_string(), ScalarValue::Utf8(Some("on".to_string())), None, ), - ); - - variables.insert( - "max_identifier_length".to_string(), DatabaseVariable::system( "max_identifier_length".to_string(), ScalarValue::UInt32(Some(63)), None, ), - ); - - variables.insert( - "role".to_string(), DatabaseVariable::system( "role".to_string(), ScalarValue::Utf8(Some("none".to_string())), None, ), - ); + // Custom cubesql variables + DatabaseVariable::user_defined( + CUBESQL_PENALIZE_POST_PROCESSING_VAR.to_string(), + ScalarValue::Boolean(Some(false)), + None, + ), + ]; + + let variables = IntoIterator::into_iter(variables) + .map(|v| (v.name.clone(), v)) + .collect::(); variables } diff --git a/rust/cubesql/cubesql/src/sql/mod.rs b/rust/cubesql/cubesql/src/sql/mod.rs index 490e7037a700c..776b13db15e78 100644 --- a/rust/cubesql/cubesql/src/sql/mod.rs +++ b/rust/cubesql/cubesql/src/sql/mod.rs @@ -15,6 +15,7 @@ pub use auth_service::{ AuthContext, AuthContextRef, AuthenticateResponse, HttpAuthContext, SqlAuthDefaultImpl, SqlAuthService, }; +pub use database_variables::postgres::session_vars::CUBESQL_PENALIZE_POST_PROCESSING_VAR; pub use postgres::*; pub use server_manager::ServerManager; pub use session::{Session, SessionProcessList, SessionProperties, SessionState};