diff --git a/rust/cubesql/cubesql/src/compile/engine/context_postgresql.rs b/rust/cubesql/cubesql/src/compile/engine/context_postgresql.rs index 06770f57c3243..b7952f7d32b18 100644 --- a/rust/cubesql/cubesql/src/compile/engine/context_postgresql.rs +++ b/rust/cubesql/cubesql/src/compile/engine/context_postgresql.rs @@ -23,9 +23,9 @@ use super::information_schema::postgres::{ PgCatalogIndexProvider, PgCatalogInheritsProvider, PgCatalogMatviewsProvider, PgCatalogNamespaceProvider, PgCatalogPartitionedTableProvider, PgCatalogProcProvider, PgCatalogRangeProvider, PgCatalogRolesProvider, PgCatalogSequenceProvider, - PgCatalogSettingsProvider, PgCatalogStatActivityProvider, PgCatalogStatUserTablesProvider, - PgCatalogStatioUserTablesProvider, PgCatalogStatsProvider, PgCatalogTableProvider, - PgCatalogTypeProvider, PgCatalogUserProvider, PgCatalogViewsProvider, + PgCatalogSettingsProvider, PgCatalogShdescriptionProvider, PgCatalogStatActivityProvider, + PgCatalogStatUserTablesProvider, PgCatalogStatioUserTablesProvider, PgCatalogStatsProvider, + PgCatalogTableProvider, PgCatalogTypeProvider, PgCatalogUserProvider, PgCatalogViewsProvider, PgPreparedStatementsProvider, }; use crate::{ @@ -136,6 +136,8 @@ impl DatabaseProtocol { "pg_catalog.pg_views".to_string() } else if let Some(_) = any.downcast_ref::() { "pg_catalog.pg_stat_user_tables".to_string() + } else if let Some(_) = any.downcast_ref::() { + "pg_catalog.pg_shdescription".to_string() } else if let Some(_) = any.downcast_ref::() { "pg_catalog.pg_external_schema".to_string() } else if let Some(_) = any.downcast_ref::() { @@ -401,6 +403,7 @@ impl DatabaseProtocol { &context.meta.tables, ))) } + "pg_shdescription" => return Some(Arc::new(PgCatalogShdescriptionProvider::new())), "pg_external_schema" => { return Some(Arc::new(RedshiftPgExternalSchemaProvider::new())) } diff --git a/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/mod.rs b/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/mod.rs index 2ca1bf2b5d36a..9b5e52bd30a67 100644 --- a/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/mod.rs +++ b/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/mod.rs @@ -34,6 +34,7 @@ mod pg_range; mod pg_roles; mod pg_sequence; mod pg_settings; +mod pg_shdescription; mod pg_stat_activity; mod pg_stat_user_tables; mod pg_statio_user_tables; @@ -69,6 +70,7 @@ pub use pg_range::*; pub use pg_roles::*; pub use pg_sequence::*; pub use pg_settings::*; +pub use pg_shdescription::*; pub use pg_stat_activity::*; pub use pg_stat_user_tables::*; pub use pg_statio_user_tables::*; diff --git a/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/pg_shdescription.rs b/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/pg_shdescription.rs new file mode 100644 index 0000000000000..5b8d0997b99cd --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/pg_shdescription.rs @@ -0,0 +1,97 @@ +use std::{any::Any, sync::Arc}; + +use async_trait::async_trait; + +use datafusion::{ + arrow::{ + array::{Array, ArrayRef, StringBuilder, UInt32Builder}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }, + datasource::{datasource::TableProviderFilterPushDown, TableProvider, TableType}, + error::DataFusionError, + logical_plan::Expr, + physical_plan::{memory::MemoryExec, ExecutionPlan}, +}; + +struct PgCatalogShdescriptionBuilder { + objoid: UInt32Builder, + classoid: UInt32Builder, + description: StringBuilder, +} + +impl PgCatalogShdescriptionBuilder { + fn new(capacity: usize) -> Self { + Self { + objoid: UInt32Builder::new(capacity), + classoid: UInt32Builder::new(capacity), + description: StringBuilder::new(capacity), + } + } + + fn finish(mut self) -> Vec> { + let columns: Vec> = vec![ + Arc::new(self.objoid.finish()), + Arc::new(self.classoid.finish()), + Arc::new(self.description.finish()), + ]; + + columns + } +} + +pub struct PgCatalogShdescriptionProvider { + data: Arc>, +} + +// https://www.postgresql.org/docs/14/catalog-pg-shdescription.html +impl PgCatalogShdescriptionProvider { + pub fn new() -> Self { + let builder = PgCatalogShdescriptionBuilder::new(0); + + Self { + data: Arc::new(builder.finish()), + } + } +} + +#[async_trait] +impl TableProvider for PgCatalogShdescriptionProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_type(&self) -> TableType { + TableType::View + } + + fn schema(&self) -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("objoid", DataType::UInt32, false), + Field::new("classoid", DataType::UInt32, false), + Field::new("description", DataType::Utf8, false), + ])) + } + + async fn scan( + &self, + projection: &Option>, + _filters: &[Expr], + _limit: Option, + ) -> Result, DataFusionError> { + let batch = RecordBatch::try_new(self.schema(), self.data.to_vec())?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + self.schema(), + projection.clone(), + )?)) + } + + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> Result { + Ok(TableProviderFilterPushDown::Unsupported) + } +} diff --git a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs index e75583b667099..6535634178916 100644 --- a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs +++ b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs @@ -3999,6 +3999,59 @@ pub fn create_inet_server_addr_udf() -> ScalarUDF { ) } +pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { + let fun = make_scalar_function(move |args: &[ArrayRef]| { + let table_oids = downcast_primitive_arg!(args[0], "table_oid", OidType); + + let result = table_oids + .iter() + .map(|_| None::) + .collect::(); + + Ok(Arc::new(result)) + }); + + create_udf( + "pg_get_partkeydef", + vec![DataType::UInt32], + Arc::new(DataType::Utf8), + Volatility::Immutable, + fun, + ) +} + +pub fn create_pg_relation_size_udf() -> ScalarUDF { + let fun = make_scalar_function(move |args: &[ArrayRef]| { + assert!(args.len() == 1); + + let relids = downcast_primitive_arg!(args[0], "relid", OidType); + + // 8192 is the lowest size for a table that has at least one column + // TODO: check if the requested table actually exists + let result = relids + .iter() + .map(|relid| relid.map(|_| 8192)) + .collect::>(); + + Ok(Arc::new(result)) + }); + + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64))); + + ScalarUDF::new( + "pg_relation_size", + &Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::UInt32]), + TypeSignature::Exact(vec![DataType::UInt32, DataType::Utf8]), + ], + Volatility::Immutable, + ), + &return_type, + &fun, + ) +} + pub fn register_fun_stubs(mut ctx: SessionContext) -> SessionContext { macro_rules! register_fun_stub { ($FTYP:ident, $NAME:expr, argc=$ARGC:expr $(, rettyp=$RETTYP:ident)? $(, vol=$VOL:ident)?) => { @@ -4863,13 +4916,6 @@ pub fn register_fun_stubs(mut ctx: SessionContext) -> SessionContext { rettyp = Utf8, vol = Volatile ); - register_fun_stub!( - udf, - "pg_relation_size", - tsigs = [[Regclass], [Regclass, Utf8],], - rettyp = Int64, - vol = Volatile - ); register_fun_stub!( udf, "pg_reload_conf", diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index d8fd0bdbccab8..d375e9cb03de0 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -5722,6 +5722,20 @@ ORDER BY Ok(()) } + #[tokio::test] + async fn test_pgcatalog_pgshdescription_postgres() -> Result<(), CubeError> { + insta::assert_snapshot!( + "pgcatalog_pgshdescription_postgres", + execute_query( + "SELECT * FROM pg_catalog.pg_shdescription".to_string(), + DatabaseProtocol::PostgreSQL + ) + .await? + ); + + Ok(()) + } + #[tokio::test] async fn test_constraint_column_usage_postgres() -> Result<(), CubeError> { insta::assert_snapshot!( diff --git a/rust/cubesql/cubesql/src/compile/parser.rs b/rust/cubesql/cubesql/src/compile/parser.rs index 4177e7f66ccec..6b4e1ccc4c097 100644 --- a/rust/cubesql/cubesql/src/compile/parser.rs +++ b/rust/cubesql/cubesql/src/compile/parser.rs @@ -73,6 +73,12 @@ pub fn parse_sql_to_statements( "SELECT n.oid,n.*,d.description FROM", "SELECT n.oid as _oid,n.*,d.description FROM", ); + let query = query.replace("SELECT c.oid,c.*,", "SELECT c.oid as _oid,c.*,"); + let query = query.replace("SELECT a.oid,a.*,", "SELECT a.oid as _oid,a.*,"); + let query = query.replace( + "LEFT OUTER JOIN pg_depend dep on dep.refobjid = a.attrelid AND dep.deptype = 'i' and dep.refobjsubid = a.attnum and dep.classid = dep.refclassid", + "LEFT OUTER JOIN pg_depend dep on dep.refobjid = a.attrelid AND dep.deptype = 'i' and dep.refobjsubid = a.attnum", + ); // TODO Superset introspection: LEFT JOIN by ANY() is not supported let query = query.replace( diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index 2b5e442f358bb..a9ff013c3dce2 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -516,6 +516,8 @@ impl QueryEngine for SqlQueryEngine { ctx.register_udf(create_pg_get_indexdef_udf()); ctx.register_udf(create_inet_server_addr_udf()); ctx.register_udf(create_age_udf()); + ctx.register_udf(create_pg_get_partkeydef_udf()); + ctx.register_udf(create_pg_relation_size_udf()); // udaf ctx.register_udaf(create_measure_udaf()); diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pgcatalog_pgshdescription_postgres.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pgcatalog_pgshdescription_postgres.snap new file mode 100644 index 0000000000000..ad8db5b2ea943 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pgcatalog_pgshdescription_postgres.snap @@ -0,0 +1,8 @@ +--- +source: cubesql/src/compile/mod.rs +expression: "execute_query(\"SELECT * FROM pg_catalog.pg_shdescription\".to_string(),\nDatabaseProtocol::PostgreSQL).await?" +--- ++--------+----------+-------------+ +| objoid | classoid | description | ++--------+----------+-------------+ ++--------+----------+-------------+ diff --git a/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_columns.snap b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_columns.snap new file mode 100644 index 0000000000000..4148f0d56245f Binary files /dev/null and b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_columns.snap differ diff --git a/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables.snap b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables.snap new file mode 100644 index 0000000000000..6bffcd63e7a4c --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables.snap @@ -0,0 +1,13 @@ +--- +source: cubesql/src/compile/test/test_introspection.rs +expression: "execute_query(r#\"\n select c.oid,pg_catalog.pg_total_relation_size(c.oid) as total_rel_size,pg_catalog.pg_relation_size(c.oid) as rel_size\n FROM pg_class c\n WHERE c.relnamespace=2200\n ORDER BY c.oid\n \"#.to_string(),\nDatabaseProtocol::PostgreSQL).await?" +--- ++-------+----------------+----------+ +| oid | total_rel_size | rel_size | ++-------+----------------+----------+ +| 18000 | 8192 | 8192 | +| 18020 | 8192 | 8192 | +| 18030 | 8192 | 8192 | +| 18036 | 8192 | 8192 | +| 18246 | 8192 | 8192 | ++-------+----------------+----------+ diff --git a/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables_with_descriptions.snap b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables_with_descriptions.snap new file mode 100644 index 0000000000000..6656934629127 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/test/snapshots/cubesql__compile__test__test_introspection__dbeaver_introspection_tables_with_descriptions.snap @@ -0,0 +1,13 @@ +--- +source: cubesql/src/compile/test/test_introspection.rs +expression: "execute_query(r#\"\n SELECT c.oid,c.*,d.description,pg_catalog.pg_get_expr(c.relpartbound, c.oid) as partition_expr, pg_catalog.pg_get_partkeydef(c.oid) as partition_key \n FROM pg_catalog.pg_class c\n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=c.oid AND d.objsubid=0 AND d.classoid='pg_class'::regclass\n WHERE c.relnamespace=2200 AND c.relkind not in ('i','I','c')\n ORDER BY c.oid\n \"#.to_string(),\nDatabaseProtocol::PostgreSQL).await?" +--- ++-------+-------+---------------------------+--------------+---------+-----------+----------+-------+-------------+---------------+----------+-----------+---------------+---------------+-------------+-------------+----------------+---------+----------+-----------+-------------+----------------+----------------+----------------+---------------------+----------------+--------------+----------------+------------+--------------+------------+--------+------------+--------------+------------+-------------------------------------------------------+----------------+---------------+ +| _oid | oid | relname | relnamespace | reltype | reloftype | relowner | relam | relfilenode | reltablespace | relpages | reltuples | relallvisible | reltoastrelid | relhasindex | relisshared | relpersistence | relkind | relnatts | relchecks | relhasrules | relhastriggers | relhassubclass | relrowsecurity | relforcerowsecurity | relispopulated | relreplident | relispartition | relrewrite | relfrozenxid | relminmxid | relacl | reloptions | relpartbound | relhasoids | description | partition_expr | partition_key | ++-------+-------+---------------------------+--------------+---------+-----------+----------+-------+-------------+---------------+----------+-----------+---------------+---------------+-------------+-------------+----------------+---------+----------+-----------+-------------+----------------+----------------+----------------+---------------------+----------------+--------------+----------------+------------+--------------+------------+--------+------------+--------------+------------+-------------------------------------------------------+----------------+---------------+ +| 18000 | 18000 | KibanaSampleDataEcommerce | 2200 | 18001 | 0 | 10 | 2 | 0 | 0 | 0 | -1 | 0 | 0 | false | false | p | r | 17 | 0 | false | false | false | false | false | true | p | false | 0 | 0 | 1 | NULL | NULL | NULL | false | Sample data for tracking eCommerce orders from Kibana | NULL | NULL | +| 18020 | 18020 | Logs | 2200 | 18021 | 0 | 10 | 2 | 0 | 0 | 0 | -1 | 0 | 0 | false | false | p | r | 7 | 0 | false | false | false | false | false | true | p | false | 0 | 0 | 1 | NULL | NULL | NULL | false | NULL | NULL | NULL | +| 18030 | 18030 | NumberCube | 2200 | 18031 | 0 | 10 | 2 | 0 | 0 | 0 | -1 | 0 | 0 | false | false | p | r | 3 | 0 | false | false | false | false | false | true | p | false | 0 | 0 | 1 | NULL | NULL | NULL | false | NULL | NULL | NULL | +| 18036 | 18036 | WideCube | 2200 | 18037 | 0 | 10 | 2 | 0 | 0 | 0 | -1 | 0 | 0 | false | false | p | r | 207 | 0 | false | false | false | false | false | true | p | false | 0 | 0 | 1 | NULL | NULL | NULL | false | NULL | NULL | NULL | +| 18246 | 18246 | MultiTypeCube | 2200 | 18247 | 0 | 10 | 2 | 0 | 0 | 0 | -1 | 0 | 0 | false | false | p | r | 67 | 0 | false | false | false | false | false | true | p | false | 0 | 0 | 1 | NULL | NULL | NULL | false | Test cube with a little bit of everything | NULL | NULL | ++-------+-------+---------------------------+--------------+---------+-----------+----------+-------+-------------+---------------+----------+-----------+---------------+---------------+-------------+-------------+----------------+---------+----------+-----------+-------------+----------------+----------------+----------------+---------------------+----------------+--------------+----------------+------------+--------------+------------+--------+------------+--------------+------------+-------------------------------------------------------+----------------+---------------+ diff --git a/rust/cubesql/cubesql/src/compile/test/test_introspection.rs b/rust/cubesql/cubesql/src/compile/test/test_introspection.rs index 0a4f3b9fc0f24..cafc287b5e2f5 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_introspection.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_introspection.rs @@ -899,6 +899,55 @@ async fn dbeaver_introspection() -> Result<(), CubeError> { .await? ); + insta::assert_snapshot!( + "dbeaver_introspection_tables_with_descriptions", + // NOTE: order by added manually to avoid random snapshot order + execute_query( + r#" + SELECT c.oid,c.*,d.description,pg_catalog.pg_get_expr(c.relpartbound, c.oid) as partition_expr, pg_catalog.pg_get_partkeydef(c.oid) as partition_key + FROM pg_catalog.pg_class c + LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=c.oid AND d.objsubid=0 AND d.classoid='pg_class'::regclass + WHERE c.relnamespace=2200 AND c.relkind not in ('i','I','c') + ORDER BY c.oid + "#.to_string(), + DatabaseProtocol::PostgreSQL + ) + .await? + ); + + insta::assert_snapshot!( + "dbeaver_introspection_tables", + // NOTE: order by added manually to avoid random snapshot order + execute_query( + r#" + select c.oid,pg_catalog.pg_total_relation_size(c.oid) as total_rel_size,pg_catalog.pg_relation_size(c.oid) as rel_size + FROM pg_class c + WHERE c.relnamespace=2200 + ORDER BY c.oid + "#.to_string(), + DatabaseProtocol::PostgreSQL + ) + .await? + ); + + insta::assert_snapshot!( + "dbeaver_introspection_columns", + execute_query( + r#" + SELECT c.relname,a.*,pg_catalog.pg_get_expr(ad.adbin, ad.adrelid, true) as def_value,dsc.description,dep.objid + FROM pg_catalog.pg_attribute a + INNER JOIN pg_catalog.pg_class c ON (a.attrelid=c.oid) + LEFT OUTER JOIN pg_catalog.pg_attrdef ad ON (a.attrelid=ad.adrelid AND a.attnum = ad.adnum) + LEFT OUTER JOIN pg_catalog.pg_description dsc ON (c.oid=dsc.objoid AND a.attnum = dsc.objsubid) + LEFT OUTER JOIN pg_depend dep on dep.refobjid = a.attrelid AND dep.deptype = 'i' and dep.refobjsubid = a.attnum and dep.classid = dep.refclassid + WHERE NOT a.attisdropped AND c.relkind not in ('i','I','c') AND c.oid=18000 + ORDER BY a.attnum + "#.to_string(), + DatabaseProtocol::PostgreSQL + ) + .await? + ); + Ok(()) } diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index 6636bba18a31f..adf1c53a1f11b 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -1117,13 +1117,21 @@ impl AsyncPostgresShim { if let Some(qtrace) = qtrace { qtrace.push_statement(&query); } - self.prepare_statement(parse.name, Ok(query), false, qtrace, span_id.clone()) - .await?; + self.prepare_statement( + parse.name, + Ok(query), + &parse.param_types, + false, + qtrace, + span_id.clone(), + ) + .await?; } Err(err) => { self.prepare_statement( parse.name, Err(parse.query.to_string()), + &parse.param_types, false, qtrace, span_id.clone(), @@ -1143,6 +1151,7 @@ impl AsyncPostgresShim { &mut self, name: String, query: Result, + param_types: &[u32], from_sql: bool, qtrace: &mut Option, span_id: Option>, @@ -1170,7 +1179,7 @@ impl AsyncPostgresShim { let (pstmt, result) = match query { Ok(query) => { - let stmt_finder = PostgresStatementParamsFinder::new(); + let stmt_finder = PostgresStatementParamsFinder::new(param_types); let parameters: Vec = stmt_finder .find(&query)? .into_iter() @@ -1703,8 +1712,15 @@ impl AsyncPostgresShim { _ => *statement, }; - self.prepare_statement(name.value, Ok(statement), true, qtrace, span_id.clone()) - .await?; + self.prepare_statement( + name.value, + Ok(statement), + &[], + true, + qtrace, + span_id.clone(), + ) + .await?; let plan = QueryPlan::MetaOk(StatusFlags::empty(), CommandCompletion::Prepare); diff --git a/rust/cubesql/cubesql/src/sql/statement.rs b/rust/cubesql/cubesql/src/sql/statement.rs index 77ef0743a4827..aec18178e24e5 100644 --- a/rust/cubesql/cubesql/src/sql/statement.rs +++ b/rust/cubesql/cubesql/src/sql/statement.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use log::trace; use pg_srv::{ protocol::{ErrorCode, ErrorResponse}, - BindValue, PgType, + BindValue, PgType, PgTypeId, }; use sqlparser::ast::{ self, ArrayAgg, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value, @@ -13,6 +13,7 @@ use std::{collections::HashMap, error::Error}; use super::types::ColumnType; use crate::sql::shim::ConnectionError; +#[derive(Debug)] enum PlaceholderType { String, Number, @@ -521,14 +522,16 @@ impl FoundParameter { } #[derive(Debug)] -pub struct PostgresStatementParamsFinder { - parameters: HashMap, +pub struct PostgresStatementParamsFinder<'t> { + parameters: HashMap, + types: &'t [u32], } -impl PostgresStatementParamsFinder { - pub fn new() -> Self { +impl<'t> PostgresStatementParamsFinder<'t> { + pub fn new(types: &'t [u32]) -> Self { Self { parameters: HashMap::new(), + types, } } @@ -544,7 +547,7 @@ impl PostgresStatementParamsFinder { } } -impl<'ast> Visitor<'ast, ConnectionError> for PostgresStatementParamsFinder { +impl<'ast, 't> Visitor<'ast, ConnectionError> for PostgresStatementParamsFinder<'t> { fn visit_value( &mut self, v: &mut ast::Value, @@ -554,8 +557,15 @@ impl<'ast> Visitor<'ast, ConnectionError> for PostgresStatementParamsFinder { Value::Placeholder(name) => { let position = self.extract_placeholder_index(&name)?; + let coltype = self + .types + .get(position) + .and_then(|pg_type_oid| PgTypeId::from_oid(*pg_type_oid)) + .and_then(|pg_type| ColumnType::from_pg_tid(pg_type).ok()) + .unwrap_or_else(|| pt.to_coltype()); + self.parameters - .insert(position.to_string(), FoundParameter::new(pt.to_coltype())); + .insert(position, FoundParameter::new(coltype)); } _ => {} }; @@ -650,6 +660,8 @@ impl<'ast> Visitor<'ast, ConnectionError> for StatementPlaceholderReplacer { placeholder_type: PlaceholderType, ) -> Result<(), ConnectionError> { match &value { + // NOTE: it does not do any harm if a numeric placeholder is replaced with a string, + // this will be handled with Bind anyway ast::Value::Placeholder(_) => { *value = match placeholder_type { PlaceholderType::String => { @@ -1238,7 +1250,7 @@ mod tests { ) -> Result<(), CubeError> { let stmts = Parser::parse_sql(&PostgreSqlDialect {}, &input).unwrap(); - let finder = PostgresStatementParamsFinder::new(); + let finder = PostgresStatementParamsFinder::new(&[]); let result = finder.find(&stmts[0]).unwrap(); assert_eq!(result, expected); diff --git a/rust/cubesql/cubesql/src/sql/types.rs b/rust/cubesql/cubesql/src/sql/types.rs index f294e5540fa05..9df4e1670440a 100644 --- a/rust/cubesql/cubesql/src/sql/types.rs +++ b/rust/cubesql/cubesql/src/sql/types.rs @@ -1,4 +1,4 @@ -use crate::compile::CommandCompletion; +use crate::{compile::CommandCompletion, CubeError}; use bitflags::bitflags; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use pg_srv::{protocol::CommandComplete, PgTypeId}; @@ -24,6 +24,26 @@ pub enum ColumnType { } impl ColumnType { + pub fn from_pg_tid(pg_type_oid: PgTypeId) -> Result { + match pg_type_oid { + PgTypeId::TEXT | PgTypeId::VARCHAR => Ok(ColumnType::String), + PgTypeId::FLOAT8 => Ok(ColumnType::Double), + PgTypeId::BOOL => Ok(ColumnType::Boolean), + PgTypeId::INT2 => Ok(ColumnType::Int8), + PgTypeId::INT4 => Ok(ColumnType::Int32), + PgTypeId::INT8 => Ok(ColumnType::Int64), + PgTypeId::BYTEA => Ok(ColumnType::Blob), + PgTypeId::DATE => Ok(ColumnType::Date(false)), + PgTypeId::INTERVAL => Ok(ColumnType::Interval(IntervalUnit::MonthDayNano)), + PgTypeId::TIMESTAMP | PgTypeId::TIMESTAMPTZ => Ok(ColumnType::Timestamp), + PgTypeId::NUMERIC => Ok(ColumnType::Decimal(38, 10)), + _ => Err(CubeError::internal(format!( + "Unable to convert PostgreSQL type oid {} to ColumnType", + pg_type_oid as u32 + ))), + } + } + pub fn to_pg_tid(&self) -> PgTypeId { match self { ColumnType::Blob => PgTypeId::BYTEA,