diff --git a/Cargo.lock b/Cargo.lock
index 980835d5..88b3a697 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -704,7 +704,8 @@ dependencies = [
  "serde",
  "serde_json",
  "socket2",
- "sqlparser",
+ "sqltk",
+ "sqltk-parser",
  "temp-env",
  "thiserror 2.0.12",
  "tokio",
@@ -1164,15 +1165,26 @@ name = "eql-mapper"
 version = "1.0.0"
 dependencies = [
  "derive_more",
+ "eql-mapper-macros",
+ "impl-trait-for-tuples",
  "itertools 0.13.0",
  "pretty_assertions",
- "sqlparser",
  "sqltk",
+ "sqltk-parser",
  "thiserror 2.0.12",
  "tracing",
  "tracing-subscriber",
 ]
 
+[[package]]
+name = "eql-mapper-macros"
+version = "2.0.0"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.2"
@@ -1803,6 +1815,17 @@ dependencies = [
  "icu_properties",
 ]
 
+[[package]]
+name = "impl-trait-for-tuples"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.100",
+]
+
 [[package]]
 name = "indexmap"
 version = "2.8.0"
@@ -3434,23 +3457,23 @@ dependencies = [
 ]
 
 [[package]]
-name = "sqlparser"
-version = "0.52.0"
+name = "sqltk"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9a875d8cd437cc8a97e9aeaeea352ec9a19aea99c23e9effb17757291de80b08"
+checksum = "7e10583ccac2be380f1dcc8ce79b751d9366e73802c31f6081e69e130545fdf3"
 dependencies = [
  "bigdecimal",
- "log",
- "serde",
+ "sqltk-parser",
 ]
 
 [[package]]
-name = "sqltk"
-version = "0.2.2"
-source = "git+https://github.com/cipherstash/sqltk/?rev=214f9b90e4f07d4414292813ffd6e45dec075fbb#214f9b90e4f07d4414292813ffd6e45dec075fbb"
+name = "sqltk-parser"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e443633c3788add36b38f86f1a9ca72fb1fd0a8141ae3e94fee41542a59daf47"
 dependencies = [
  "bigdecimal",
- "sqlparser",
+ "log",
 ]

diff --git a/Cargo.toml b/Cargo.toml
index a93f0c70..b131db7c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,7 +35,8 @@ strip = "none"
 debug = true
 
 [workspace.dependencies]
-sqlparser = { version = "^0.52", features = ["bigdecimal", "serde"] }
+sqltk = { version = "0.5.0" }
+sqltk-parser = { version = "0.52.0" }
 thiserror = "2.0.9"
 tokio = { version = "1.44", features = ["full"] }
 tracing = "0.1"

diff --git a/mise.local.example.toml b/mise.local.example.toml
index cd8f0241..a1069871 100644
--- a/mise.local.example.toml
+++ b/mise.local.example.toml
@@ -42,6 +42,6 @@ CS_LOG__AUTHENTICATION_LEVEL = "info"
 CS_LOG__CONTEXT_LEVEL = "info"
 CS_LOG__KEYSET_LEVEL = "info"
 CS_LOG__PROTOCOL_LEVEL = "info"
-CS_LOG__MAPPER_LEVEL = "info"
+CS_LOG__MAPPER_LEVEL = "debug"
 CS_LOG__SCHEMA_LEVEL = "info"
 CS_LOG__CONFIG_LEVEL = "info"

diff --git a/mise.toml b/mise.toml
index 98610bb4..9cf04e2a 100644
--- a/mise.toml
+++ b/mise.toml
@@ -108,6 +108,7 @@ run = "docker compose rm --stop --force proxy proxy-tls"
 alias = ['t', 'ci']
 description = "Runs all tests (hygiene, unit, integration)"
 run = """
+mise run rust:version
 mise run test:check
 mise run test:format
 mise run test:clippy
@@ -369,6 +370,15 @@ mise --env tls run proxy:down
 description = "Runs cargo nextest, skipping integration tests"
 run = 'cargo nextest run --no-fail-fast --nocapture -E "not package(cipherstash-proxy-integration)" {{arg(name="test",default="")}}'
 
+[tasks."rust:version"]
+description = "Outputs rust toolchain version info"
+run = """
+echo "rustc --version = " $(rustc --version)
+echo "cargo --version = " $(cargo --version)
+echo "cargo fmt --version = " $(cargo fmt --version)
+echo "cargo clippy --version = " $(cargo clippy --version)
+"""
+
 [tasks."test:format"]
 description = "Runs cargo fmt"
 run = 'cargo fmt --all -- --check'
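Note on the manifest changes above: the crates.io `sqlparser` dependency is replaced by the `sqltk-parser` fork (same 0.52 lineage) plus the `sqltk` traversal toolkit. A minimal sketch of what a call site looks like after the rename, assuming the fork preserves the upstream `sqlparser` 0.52 parsing API (the `Parser` and `PostgreSqlDialect` paths match the imports used in `frontend.rs` later in this diff):

```rust
use sqltk_parser::dialect::PostgreSqlDialect;
use sqltk_parser::parser::Parser;

fn main() {
    // Same entry point as upstream sqlparser 0.52; only the crate name changes.
    let ast = Parser::parse_sql(&PostgreSqlDialect {}, "SELECT email FROM users")
        .expect("valid SQL");
    println!("{ast:?}");
}
```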
= " $(cargo --version) +echo "cargo fmt --version = " $(cargo fmt --version) +echo "cargo clippy --version = " $(cargo clippy --version) +""" + [tasks."test:format"] description = "Runs cargo fmt" run = 'cargo fmt --all -- --check' diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index a189d091..bc285caa 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -40,7 +40,8 @@ rustls-pki-types = "1.10.0" serde = "1.0" serde_json = "1.0" socket2 = "0.5.7" -sqlparser = { workspace = true } +sqltk = { workspace = true } +sqltk-parser = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } tokio-postgres = { version = "0.7", features = [ diff --git a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs index 9ad5c52f..dbf48730 100644 --- a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs +++ b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::{connect, log::SCHEMA}; use arc_swap::ArcSwap; use eql_mapper::{Column, Schema, Table}; -use sqlparser::ast::Ident; +use sqltk_parser::ast::Ident; use std::sync::Arc; use std::time::Duration; use tokio::{task::JoinHandle, time}; diff --git a/packages/cipherstash-proxy/src/eql/mod.rs b/packages/cipherstash-proxy/src/eql/mod.rs index 1d1c3a9e..59cc764d 100644 --- a/packages/cipherstash-proxy/src/eql/mod.rs +++ b/packages/cipherstash-proxy/src/eql/mod.rs @@ -3,7 +3,7 @@ use cipherstash_client::{ zerokms::{encrypted_record, EncryptedRecord}, }; use serde::{Deserialize, Serialize}; -use sqlparser::ast::Ident; +use sqltk_parser::ast::Ident; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub struct Plaintext { diff --git a/packages/cipherstash-proxy/src/error.rs b/packages/cipherstash-proxy/src/error.rs index 36bc4aea..023cc4c8 100644 --- a/packages/cipherstash-proxy/src/error.rs +++ b/packages/cipherstash-proxy/src/error.rs @@ -1,6 +1,7 @@ use crate::{postgresql::Column, Identifier}; use bytes::BytesMut; use cipherstash_client::encryption; +use eql_mapper::EqlMapperError; use metrics_exporter_prometheus::BuildError; use std::{io, time::Duration}; use thiserror::Error; @@ -92,6 +93,9 @@ pub enum MappingError { #[error("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. 
diff --git a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
index b44fad3b..a138406a 100644
--- a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
+++ b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
@@ -11,7 +11,7 @@ use cipherstash_config::ColumnType;
 use postgres_types::FromSql;
 use postgres_types::Type;
 use rust_decimal::Decimal;
-use sqlparser::ast::Value;
+use sqltk_parser::ast::Value;
 use std::str::FromStr;
 use tracing::debug;

diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs
index 53b9758e..5a62cd2a 100644
--- a/packages/cipherstash-proxy/src/postgresql/frontend.rs
+++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs
@@ -26,13 +26,14 @@ use crate::prometheus::{
 use crate::Encrypted;
 use bytes::BytesMut;
 use cipherstash_client::encryption::Plaintext;
-use eql_mapper::{self, EqlMapperError, EqlValue, NodeKey, TableColumn, TypedStatement};
+use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement};
 use metrics::{counter, histogram};
 use pg_escape::quote_literal;
 use serde::Serialize;
-use sqlparser::ast::{self, Expr, Value};
-use sqlparser::dialect::PostgreSqlDialect;
-use sqlparser::parser::Parser;
+use sqltk::NodeKey;
+use sqltk_parser::ast::{self, Value};
+use sqltk_parser::dialect::PostgreSqlDialect;
+use sqltk_parser::parser::Parser;
 use std::collections::HashMap;
 use std::time::Instant;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -285,7 +286,7 @@
 
         match self.to_encryptable_statement(&typed_statement, vec![])? {
             Some(statement) => {
-                if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
+                if typed_statement.requires_transform() {
                     let encrypted_literals = self
                         .encrypt_literals(&typed_statement, &statement.literal_columns)
                         .await?;
@@ -356,7 +357,7 @@
     ///
     async fn encrypt_literals(
         &mut self,
-        typed_statement: &TypedStatement<'_>,
+        typed_statement: &TypeCheckedStatement<'_>,
         literal_columns: &Vec<Option<Column>>,
     ) -> Result<Vec<Option<Encrypted>>, Error> {
         let literal_values = typed_statement.literal_values();
@@ -402,14 +403,14 @@
     ///
     async fn transform_statement(
         &mut self,
-        typed_statement: &TypedStatement<'_>,
+        typed_statement: &TypeCheckedStatement<'_>,
         encrypted_literals: &Vec<Option<Encrypted>>,
     ) -> Result<Option<ast::Statement>, Error> {
         // Convert literals to ast Expr
         let mut encrypted_expressions = vec![];
         for encrypted in encrypted_literals {
             let e = match encrypted {
-                Some(en) => Some(to_json_literal_expr(&en)?),
+                Some(en) => Some(to_json_literal_value(&en)?),
                 None => None,
             };
             encrypted_expressions.push(e);
@@ -426,11 +427,10 @@
         debug!(target: MAPPER,
             client_id = self.context.client_id,
-            nodes_to_wrap = typed_statement.nodes_to_wrap.len(),
             literals = encrypted_nodes.len(),
         );
 
-        if !typed_statement.has_nodes_to_wrap() && encrypted_nodes.is_empty() {
+        if !typed_statement.requires_transform() {
             return Ok(None);
         }
@@ -500,7 +500,7 @@
 
         match self.to_encryptable_statement(&typed_statement, param_types)? {
            Some(statement) => {
-                if statement.has_literals() || typed_statement.has_nodes_to_wrap() {
+                if typed_statement.requires_transform() {
                     let encrypted_literals = self
                         .encrypt_literals(&typed_statement, &statement.literal_columns)
                         .await?;
@@ -607,7 +607,7 @@
     ///
     fn to_encryptable_statement(
         &self,
-        typed_statement: &TypedStatement<'_>,
+        typed_statement: &TypeCheckedStatement<'_>,
         param_types: Vec,
     ) -> Result, Error> {
         let param_columns = self.get_param_columns(typed_statement)?;
@@ -619,8 +619,7 @@
 
         if (param_columns.is_empty() || no_encrypted_param_columns)
             && (projection_columns.is_empty() || no_encrypted_projection_columns)
-            && literal_columns.is_empty()
-            && !typed_statement.has_nodes_to_wrap()
+            && !typed_statement.requires_transform()
         {
             return Ok(None);
         }
@@ -735,7 +734,10 @@
         Ok(encrypted)
     }
 
-    fn type_check<'a>(&self, statement: &'a ast::Statement) -> Result<TypedStatement<'a>, Error> {
+    fn type_check<'a>(
+        &self,
+        statement: &'a ast::Statement,
+    ) -> Result<TypeCheckedStatement<'a>, Error> {
         match eql_mapper::type_check(self.context.get_table_resolver(), statement) {
             Ok(typed_statement) => {
                 debug!(target: MAPPER,
@@ -778,10 +780,10 @@
     ///
     fn get_projection_columns(
         &self,
-        typed_statement: &eql_mapper::TypedStatement<'_>,
+        typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
     ) -> Result<Vec<Option<Column>>, Error> {
         let mut projection_columns = vec![];
-        if let Some(eql_mapper::Projection::WithColumns(columns)) = &typed_statement.projection {
+        if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection {
             for col in columns {
                 let eql_mapper::ProjectionColumn { ty, .. } = col;
                 let configured_column = match ty {
@@ -813,13 +815,13 @@
     ///
     fn get_param_columns(
         &self,
-        typed_statement: &eql_mapper::TypedStatement<'_>,
+        typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
     ) -> Result<Vec<Option<Column>>, Error> {
         let mut param_columns = vec![];
 
         for param in typed_statement.params.iter() {
             let configured_column = match param {
-                eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => {
+                (_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => {
                     let identifier = Identifier::from((table, column));
 
                     debug!(
@@ -841,7 +843,7 @@
 
     fn get_literal_columns(
         &self,
-        typed_statement: &eql_mapper::TypedStatement<'_>,
+        typed_statement: &eql_mapper::TypeCheckedStatement<'_>,
     ) -> Result<Vec<Option<Column>>, Error> {
         let mut literal_columns = vec![];
@@ -967,9 +969,9 @@ fn literals_to_plaintext(
     Ok(plaintexts)
 }
 
-fn to_json_literal_expr<T>(literal: &T) -> Result<Expr, Error>
+fn to_json_literal_value<T>(literal: &T) -> Result<Value, Error>
 where
     T: ?Sized + Serialize,
 {
-    Ok(serde_json::to_string(literal).map(|json| Expr::Value(Value::SingleQuotedString(json)))?)
+    Ok(serde_json::to_string(literal).map(Value::SingleQuotedString)?)
 }
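The renamed helper now produces a bare `ast::Value` rather than a full `Expr`. A standalone sketch of its behaviour (simplified to return `serde_json::Error` directly; the real code maps into the proxy's own `Error` type):

```rust
use serde::Serialize;
use sqltk_parser::ast::Value;

fn to_json_literal_value<T>(literal: &T) -> Result<Value, serde_json::Error>
where
    T: ?Sized + Serialize,
{
    // Serialize the encrypted payload to JSON, then wrap it as a SQL
    // single-quoted string literal.
    serde_json::to_string(literal).map(Value::SingleQuotedString)
}

fn main() -> Result<(), serde_json::Error> {
    let value = to_json_literal_value(&serde_json::json!({ "c": 1 }))?;
    // Displays as the SQL literal '{"c":1}'.
    println!("{value}");
    Ok(())
}
```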
diff --git a/packages/eql-mapper-macros/Cargo.toml b/packages/eql-mapper-macros/Cargo.toml
new file mode 100644
index 00000000..6ceb98c2
--- /dev/null
+++ b/packages/eql-mapper-macros/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "eql-mapper-macros"
+version.workspace = true
+edition.workspace = true
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = { version = "2.0", features = ["full"] }
+quote = "1.0"
+proc-macro2 = "1.0"
\ No newline at end of file

diff --git a/packages/eql-mapper-macros/src/lib.rs b/packages/eql-mapper-macros/src/lib.rs
new file mode 100644
index 00000000..79bb5872
--- /dev/null
+++ b/packages/eql-mapper-macros/src/lib.rs
@@ -0,0 +1,115 @@
+use proc_macro::TokenStream;
+use quote::{quote, ToTokens};
+use syn::{
+    parse::Parse, parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItem, ImplItemFn,
+    ItemImpl, Pat, PatType, Signature, Type, TypePath, TypeReference,
+};
+
+/// This macro generates consistently defined `#[tracing::instrument]` attributes for `InferType::infer_enter` &
+/// `InferType::infer_exit` implementations on `TypeInferencer`.
+///
+/// This attribute MUST be defined on the trait `impl` itself (not the trait method impls).
+#[proc_macro_attribute]
+pub fn trace_infer(_attr: TokenStream, item: TokenStream) -> TokenStream {
+    let mut input = parse_macro_input!(item as ItemImpl);
+
+    for item in &mut input.items {
+        if let ImplItem::Fn(ImplItemFn {
+            attrs,
+            sig:
+                Signature {
+                    ident: method,
+                    inputs,
+                    ..
+                },
+            ..
+        }) = item
+        {
+            let node_ident_and_type: Option<(&Ident, &Type)> =
+                if let Some(FnArg::Typed(PatType {
+                    ty: node_ty, pat, ..
+                })) = inputs.get(1)
+                {
+                    if let Pat::Ident(pat_ident) = &**pat {
+                        Some((&pat_ident.ident, node_ty))
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                };
+
+            let vec_ident: Ident = parse_quote!(Vec);
+
+            match node_ident_and_type {
+                Some((node_ident, node_ty)) => {
+                    let (formatter, node_ty_abbrev) = match node_ty {
+                        Type::Reference(TypeReference { elem, .. }) => match &**elem {
+                            Type::Path(TypePath { path, .. }) => {
+                                let last_segment = path.segments.last().unwrap();
+                                let last_segment_ident = &last_segment.ident;
+                                let last_segment_arguments = if last_segment.arguments.is_empty() {
+                                    None
+                                } else {
+                                    let args = &last_segment.arguments;
+                                    Some(quote!(<#args>))
+                                };
+                                match last_segment_ident {
+                                    ident if vec_ident == *ident => {
+                                        (quote!(crate::FmtAstVec), quote!(#last_segment_ident #last_segment_arguments))
+                                    }
+                                    _ => (quote!(crate::FmtAst), quote!(#last_segment_ident #last_segment_arguments)),
+                                }
+                            }
+                            _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>"),
+                        },
+                        _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>"),
+                    };
+
+                    let node_ty_abbrev = node_ty_abbrev
+                        .to_token_stream()
+                        .to_string()
+                        .replace(" ", "");
+
+                    let target = format!("eql-mapper::{}", method.to_string().to_uppercase());
+
+                    let attr: TracingInstrumentAttr = syn::parse2(quote! {
+                        #[tracing::instrument(
+                            target = #target,
+                            level = "trace",
+                            skip(self, #node_ident),
+                            fields(
+                                ast_ty = #node_ty_abbrev,
+                                ast = %#formatter(#node_ident),
+                            ),
+                            ret(Debug)
+                        )]
+                    })
+                    .unwrap();
+                    attrs.push(attr.attr);
+                }
+                None => {
+                    return quote!(compile_error!(
+                        "could not determine name of node argument in Infer impl"
+                    ))
+                    .to_token_stream()
+                    .into();
+                }
+            }
+        }
+    }
+
+    input.to_token_stream().into()
+}
+
+struct TracingInstrumentAttr {
+    attr: Attribute,
+}
+
+impl Parse for TracingInstrumentAttr {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        Ok(Self {
+            attr: Attribute::parse_outer(input)?.first().unwrap().clone(),
+        })
+    }
+}
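The `TracingInstrumentAttr` helper at the bottom of `lib.rs` is the one non-obvious piece: attribute tokens built with `quote!` are re-parsed into a `syn::Attribute` so they can be pushed onto a method's `attrs`. The same round-trip is runnable outside a proc-macro context (this sketch only assumes the `syn`/`quote` dependencies declared above):

```rust
use quote::quote;
use syn::{parse::Parse, Attribute};

struct OuterAttr {
    attr: Attribute,
}

impl Parse for OuterAttr {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        // Attribute::parse_outer parses zero or more #[...] attributes;
        // keep the first one, exactly as TracingInstrumentAttr does.
        Ok(Self {
            attr: Attribute::parse_outer(input)?.first().unwrap().clone(),
        })
    }
}

fn main() {
    let parsed: OuterAttr = syn::parse2(quote! {
        #[tracing::instrument(level = "trace", skip(self))]
    })
    .unwrap();

    let last = parsed.attr.path().segments.last().unwrap();
    assert_eq!(last.ident.to_string(), "instrument");
}
```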
diff --git a/packages/eql-mapper/Cargo.toml b/packages/eql-mapper/Cargo.toml
index e1724278..56c3caac 100644
--- a/packages/eql-mapper/Cargo.toml
+++ b/packages/eql-mapper/Cargo.toml
@@ -11,10 +11,12 @@ authors = [
 ]
 
 [dependencies]
+eql-mapper-macros = { path = "../eql-mapper-macros" }
 derive_more = { version = "^1.0", features = ["display", "constructor"] }
+impl-trait-for-tuples = "0.2.3"
 itertools = "^0.13"
-sqlparser = { workspace = true }
-sqltk = { git = "https://github.com/cipherstash/sqltk/", rev = "214f9b90e4f07d4414292813ffd6e45dec075fbb" }
+sqltk = { workspace = true }
+sqltk-parser = { workspace = true }
 thiserror = { workspace = true }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true }

diff --git a/packages/eql-mapper/src/display_helpers.rs b/packages/eql-mapper/src/display_helpers.rs
new file mode 100644
index 00000000..e575d1b3
--- /dev/null
+++ b/packages/eql-mapper/src/display_helpers.rs
@@ -0,0 +1,130 @@
+use std::{
+    collections::HashMap,
+    fmt::{Debug, Display},
+};
+
+use sqltk::NodeKey;
+use sqltk_parser::ast::{
+    Delete, Expr, Function, Insert, Query, Select, SelectItem, SetExpr, Statement, Value, Values,
+};
+
+use crate::{EqlValue, Param, Type};
+
+#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
+pub struct Fmt<T>(pub(crate) T);
+
+#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
+pub struct FmtAst<T>(pub(crate) T);
+
+#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
+pub struct FmtAstVec<T>(pub(crate) T);
+
+impl Display for Fmt<NodeKey<'_>> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if let Some(node) = self.0.get_as::<Statement>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<Query>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<Insert>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<Delete>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<Expr>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<SetExpr>() {
+            return Display::fmt(&FmtAst(node), f);
+        }
+        if let Some(node) = self.0.get_as::<Select>() {
-            into_control_flow(self.infer_enter(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::>() {
-            into_control_flow(self.infer_enter(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_enter(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_enter(node))?
-        }
-
+        dispatch_all!(self, infer_enter, node);
         ControlFlow::Continue(())
     }
 
     fn exit<N: Visitable>(&mut self, node: &'ast N) -> ControlFlow<Break<Self::Error>> {
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            into_control_flow(self.infer_exit(node))?
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            self.0.dump_node(node);
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            self.0.dump_node(node);
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            self.0.dump_node(node);
-        }
-
-        if let Some(node) = node.downcast_ref::() {
-            self.0.dump_node(node);
-        }
-
-        ControlFlow::Continue(())
-    }
+    /// Gets (and creates, if required) the [`Type`] associated with a node. If the node does not already have an
+    /// associated `Type` then a fresh [`Type::Var`] will be assigned.
+    fn get_or_init_node_type<N: AsNodeKey>(&mut self, node: &'ast N) -> Arc<Type> {
+        match self.peek_node_type(node) {
+            Some(ty) => ty,
+            None => {
+                let tvar = self.fresh_tvar();
+                self.node_types.insert(node.as_node_key(), tvar);
+                Type::Var(tvar).into()
             }
         }
     }
 
-        root_node.accept(&mut FindNodeFromKeyVisitor(self));
+    /// Gets (and creates, if required) the [`Type`] associated with a param. If the param does not already have an
+    /// associated `Type` then a fresh [`Type::Var`] will be assigned.
+    fn get_or_init_param_type(&mut self, param: &'ast String) -> Arc<Type> {
+        match self.param_types.get(&param).cloned() {
+            Some(tvar) => Type::Var(tvar).into(),
+            None => {
+                let tvar = self.fresh_tvar();
+                self.param_types.insert(param, tvar);
+                Type::Var(tvar).into()
+            }
         }
     }
+
+    pub(crate) fn fresh_tvar(&mut self) -> TypeVar {
+        self.tvar_seq.next_value()
+    }
 }

diff --git a/packages/eql-mapper/src/inference/sequence.rs b/packages/eql-mapper/src/inference/sequence.rs
new file mode 100644
index 00000000..0c026086
--- /dev/null
+++ b/packages/eql-mapper/src/inference/sequence.rs
@@ -0,0 +1,26 @@
+use std::marker::PhantomData;
+
+use super::unifier::TypeVar;
+
+#[derive(Debug)]
+pub(crate) struct Sequence<T> {
+    next_value: usize,
+    _marker: PhantomData<T>,
+}
+
+impl<T> Sequence<T> {
+    pub(crate) fn new() -> Self {
+        Self {
+            next_value: 0,
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl Sequence<TypeVar> {
+    pub(crate) fn next_value(&mut self) -> TypeVar {
+        let t = TypeVar(self.next_value);
+        self.next_value += 1;
+        t
+    }
+}
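A quick sketch of how the new `Sequence<TypeVar>` behaves, as a standalone copy (the real types are crate-private; it supersedes the `TypeVarGenerator` deleted in the `type_variables.rs` hunk below):

```rust
use std::marker::PhantomData;

#[derive(Debug, PartialEq)]
struct TypeVar(usize);

struct Sequence<T> {
    next_value: usize,
    _marker: PhantomData<T>,
}

impl<T> Sequence<T> {
    fn new() -> Self {
        Self { next_value: 0, _marker: PhantomData }
    }
}

impl Sequence<TypeVar> {
    fn next_value(&mut self) -> TypeVar {
        let t = TypeVar(self.next_value);
        self.next_value += 1;
        t
    }
}

fn main() {
    let mut seq = Sequence::<TypeVar>::new();
    // Each call yields a fresh, monotonically increasing type variable.
    assert_eq!(seq.next_value(), TypeVar(0));
    assert_eq!(seq.next_value(), TypeVar(1));
}
```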
diff --git a/packages/eql-mapper/src/inference/type_error.rs b/packages/eql-mapper/src/inference/type_error.rs
index 0aa4a816..2011a3d8 100644
--- a/packages/eql-mapper/src/inference/type_error.rs
+++ b/packages/eql-mapper/src/inference/type_error.rs
@@ -1,9 +1,7 @@
-use std::collections::HashSet;
+use std::{collections::HashSet, sync::Arc};
 
 use crate::{unifier::Type, SchemaError, ScopeError};
 
-use super::TypeCell;
-
 #[derive(Debug, PartialEq, Eq, thiserror::Error)]
 pub enum TypeError {
     #[error("SQL feature {} is not supported", _0)]
@@ -24,9 +22,6 @@
     #[error("One or more params failed to unify: {}", _0.iter().cloned().collect::<Vec<_>>().join(", "))]
     Params(HashSet<String>),
 
-    #[error("Expected scalar type for param {} but got type {}", _0, _1)]
-    NonScalarParam(String, String),
-
     #[error("Expected param count to be {}, but got {}", _0, _1)]
     ParamCount(usize, usize),
@@ -39,8 +34,8 @@
     #[error("{}", _0)]
     SchemaError(#[from] SchemaError),
 
-    #[error("Cannot unify node types for nodes:\n 1. node: {} type: {}\n 2. node: {} type: {}\n error: {}", _1, *_2.as_type(), _3, *_4.as_type(), _0)]
-    OnNodes(Box<TypeError>, String, TypeCell, String, TypeCell),
+    #[error("Cannot unify node types for nodes:\n 1. node: {} type: {}\n 2. node: {} type: {}\n error: {}", _1, _2, _3, _4, _0)]
+    OnNodes(Box<TypeError>, String, Arc<Type>, String, Arc<Type>),
 
     #[error(
         "Cannot unify node with type:\n node: {}\n type: {} error: {}",

diff --git a/packages/eql-mapper/src/inference/type_variables.rs b/packages/eql-mapper/src/inference/type_variables.rs
deleted file mode 100644
index 182546c4..00000000
--- a/packages/eql-mapper/src/inference/type_variables.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-/// A type variable generator.
-///
-/// Every time the [`Unifier`] sees a [`TypeVar::Fresh`] it replaces it with a `TypeVar::Assigned(_)` using the next
-/// `u32` in the sequence.
-#[derive(Debug, Default)]
-pub(crate) struct TypeVarGenerator(u32);
-
-impl TypeVarGenerator {
-    /// Creates a new `TypeVarGenerator`.
-    pub(crate) fn new() -> Self {
-        Self(0)
-    }
-
-    /// Gets the next type variable.
-    pub(crate) fn next_tvar(&mut self) -> u32 {
-        let next_id = self.0;
-        self.0 += 1;
-        next_id
-    }
-}
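With `TypeCell` gone (its module is removed in the `unifier/mod.rs` diff below), unified types are plain `Arc<Type>` values and variable bindings live in a substitution map, matching the new `get_substitutions() -> HashMap<TypeVar, Arc<Type>>`. A toy sketch of that representation, with assumed, simplified types:

```rust
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, PartialEq)]
enum Ty {
    Var(usize),
    Native,
}

fn main() {
    // Substitution-map style: unification never mutates a shared cell,
    // it records that a type variable resolves to some type.
    let mut subs: HashMap<usize, Arc<Ty>> = HashMap::new();
    let t0 = Arc::new(Ty::Var(0));

    // "Unify" Var(0) with Native by recording a substitution.
    subs.insert(0, Arc::new(Ty::Native));

    // Following the substitution resolves the variable.
    if let Ty::Var(v) = &*t0 {
        assert_eq!(*subs[v], Ty::Native);
    }
}
```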
diff --git a/packages/eql-mapper/src/inference/unifier/mod.rs b/packages/eql-mapper/src/inference/unifier/mod.rs
index d5930547..daac059a 100644
--- a/packages/eql-mapper/src/inference/unifier/mod.rs
+++ b/packages/eql-mapper/src/inference/unifier/mod.rs
@@ -1,35 +1,93 @@
-use std::{cell::RefCell, rc::Rc};
+use std::{cell::RefCell, collections::HashMap, rc::Rc, sync::Arc};
 
-mod type_cell;
 mod types;
 
 use crate::inference::TypeError;
-pub(crate) use type_cell::*;
+use sqltk::AsNodeKey;
 pub(crate) use types::*;
 pub use types::{EqlValue, NativeValue, TableColumn};
 
 use super::TypeRegistry;
-use tracing::{span, Level};
+use tracing::{event, instrument, Level};
 
 /// Implements the type unification algorithm and maintains an association of type variables with the type that they
 /// point to.
 #[derive(Debug)]
 pub struct Unifier<'ast> {
     registry: Rc<RefCell<TypeRegistry<'ast>>>,
-    depth: usize,
 }
 
 impl<'ast> Unifier<'ast> {
     /// Creates a new `Unifier`.
-    pub fn new(reg: impl Into<Rc<RefCell<TypeRegistry<'ast>>>>) -> Self {
+    pub fn new(registry: impl Into<Rc<RefCell<TypeRegistry<'ast>>>>) -> Self {
         Self {
-            registry: reg.into(),
-            depth: 0,
+            registry: registry.into(),
         }
     }
 
+    pub(crate) fn fresh_tvar(&self) -> Arc<Type> {
+        Type::Var(self.registry.borrow_mut().fresh_tvar()).into()
+    }
+
+    pub(crate) fn get_substitutions(&self) -> HashMap<TypeVar, Arc<Type>> {
+        self.registry.borrow().get_substititions()
+    }
+
+    /// Looks up a previously registered [`Type`] by its [`TypeVar`].
+    pub(crate) fn get_type(&self, tvar: TypeVar) -> Option<Arc<Type>> {
+        self.registry.borrow().get_type(tvar)
+    }
+
+    pub(crate) fn get_node_type<N: AsNodeKey>(&self, node: &'ast N) -> Arc<Type> {
+        let node_type = { self.registry.borrow_mut().get_node_type(node) };
+        node_type.follow_tvars(self)
+    }
+
+    pub(crate) fn peek_node_type<N: AsNodeKey>(&self, node: &'ast N) -> Option<Arc<Type>> {
+        self.registry.borrow_mut().peek_node_type(node)
+    }
+
+    pub(crate) fn get_param_type(&mut self, param: &'ast String) -> Arc<Type> {
+        self.registry.borrow_mut().get_param_type(param)
+    }
+
+    /// [`sqltk_parser::ast::Value`] nodes with type `Type::Var(_)` after the inference phase is complete will be
+    /// unified with [`NativeValue`].
+    ///
+    /// This can happen when a literal or param is never used in an expression that would constrain its type.
+    ///
+    /// In that case, it is safe to resolve its type as native because it cannot possibly be an EQL type; EQL
+    /// types are always correctly inferred.
+    pub(crate) fn resolve_unresolved_value_nodes(&mut self) -> Result<(), TypeError> {
+        let unresolved_value_nodes: Vec<_> = self
+            .registry
+            .borrow()
+            .get_nodes_and_types::<Value>()
+            .into_iter()
+            .map(|(node, ty)| (node, ty.follow_tvars(&*self)))
+            .filter(|(_, ty)| matches!(&**ty, Type::Var(_)))
+            .collect();
+
+        for (_, ty) in unresolved_value_nodes {
+            self.unify(ty, Type::any_native().into())?;
+        }
+
+        Ok(())
+    }
+
+    pub(crate) fn substitute(&mut self, tvar: TypeVar, sub_ty: Arc<Type>) -> Arc<Type> {
+        event!(
+            target: "eql-mapper::EVENT_SUBSTITUTE",
+            Level::TRACE,
+            tvar = %tvar,
+            sub_ty = %sub_ty,
+        );
+
+        self.registry.borrow_mut().substitute(tvar, sub_ty)
+    }
 
     /// Unifies two [`Type`]s or fails with a [`TypeError`].
     ///
     /// "Type Unification" is a fancy term for finding a set of type variable substitutions for multiple types
     /// dangling type variables).
     ///
     /// Returns `Ok(ty)` if successful, or `Err(TypeError)` on failure.
-    pub(crate) fn unify(&mut self, left: TypeCell, right: TypeCell) -> Result<TypeCell, TypeError> {
+    #[instrument(
+        target = "eql-mapper::UNIFY",
+        skip(self),
+        level = "trace",
+        ret(Display),
+        err(Debug),
+        fields(
+            lhs = %lhs,
+            rhs = %rhs,
+        )
+    )]
+    pub(crate) fn unify(&mut self, lhs: Arc<Type>, rhs: Arc<Type>) -> Result<Arc<Type>, TypeError> {
         use types::Constructor::*;
         use types::Value::*;
 
-        let span = span!(
-            Level::DEBUG,
-            "unify",
-            depth = self.depth,
-            left = left.as_type().to_string(),
-            right = right.as_type().to_string()
-        );
-
-        let _guard = span.enter();
+        let lhs: Arc<Type> = lhs;
+        let rhs: Arc<Type> = rhs;
 
-        self.depth += 1;
-
-        // Short-circuit the unification when left & right are equal.
-        if left == right {
-            return Ok(left.join(&right));
+        // Short-circuit the unification when lhs & rhs are equal.
+        if lhs == rhs {
+            return Ok(lhs.clone());
         }
 
-        let (a, b) = (left.as_type(), right.as_type());
-
-        let unification = match (&*a, &*b) {
+        let unification = match (&*lhs, &*rhs) {
             // Two projections unify if they have the same number of columns and all of the paired column types also
             // unify.
             (Type::Constructor(Projection(_)), Type::Constructor(Projection(_))) => {
-                Ok(self.unify_projections(left.clone(), right.clone())?)
+                self.unify_projections(lhs, rhs)
             }
 
             // Two arrays unify if the types of their element types unify.
             (
-                Type::Constructor(Value(Array(element_ty_left))),
-                Type::Constructor(Value(Array(element_ty_right))),
+                Type::Constructor(Value(Array(lhs_element_ty))),
+                Type::Constructor(Value(Array(rhs_element_ty))),
             ) => {
-                let element_ty = self.unify(element_ty_left.clone(), element_ty_right.clone())?;
-
-                Ok(left.join_all(&[
-                    &right,
-                    &TypeCell::new(Type::Constructor(Value(Array(element_ty)))),
-                ]))
+                let unified_element_ty =
+                    self.unify(lhs_element_ty.clone(), rhs_element_ty.clone())?;
+                let unified_array_ty = Type::Constructor(Value(Array(unified_element_ty)));
+                Ok(unified_array_ty.into())
             }
 
             // A Value can unify with a single column projection
             (Type::Constructor(Value(_)), Type::Constructor(Projection(projection))) => {
                 let projection = projection.flatten();
                 let len = projection.len();
                 if len == 1 {
-                    let unified =
-                        self.unify_value_type_with_type(left.clone(), projection[0].ty.clone())?;
-
-                    Ok(TypeCell::join_all(&left, &[&right, &unified]))
+                    self.unify_value_type_with_one_col_projection(lhs, projection[0].ty.clone())
                 } else {
                     Err(TypeError::Conflict(
                         "cannot unify value type with projection of more than one column"
                             .to_string(),
                     ))
                 }
             }
 
             (Type::Constructor(Projection(projection)), Type::Constructor(Value(_))) => {
                 let projection = projection.flatten();
                 let len = projection.len();
                 if len == 1 {
-                    let unified =
-                        self.unify_value_type_with_type(right.clone(), projection[0].ty.clone())?;
-                    Ok(TypeCell::join_all(&left, &[&right, &unified]))
+                    self.unify_value_type_with_one_col_projection(rhs, projection[0].ty.clone())
                 } else {
                     Err(TypeError::Conflict(
                         "cannot unify value type with projection of more than one column"
                             .to_string(),
                     ))
                 }
             }
 
-            (Type::Constructor(Value(Native(_))), Type::Constructor(Value(Native(_)))) => {
-                Ok(left.join(&right))
-            }
-
-            (Type::Constructor(Value(Eql(lhs))), Type::Constructor(Value(Eql(rhs)))) => {
+            // All native types are considered equal in the type system. However, for improved test readability the
+            // unifier favours a `NativeValue(Some(_))` over a `NativeValue(None)` because `NativeValue(Some(_))`
+            // carries more information. In a tie, the left hand side wins.
+            (
+                Type::Constructor(Value(Native(native_lhs))),
+                Type::Constructor(Value(Native(native_rhs))),
+            ) => match (native_lhs, native_rhs) {
+                (NativeValue(Some(_)), NativeValue(Some(_))) => Ok(lhs),
+                (NativeValue(Some(_)), NativeValue(None)) => Ok(lhs),
+                (NativeValue(None), NativeValue(Some(_))) => Ok(rhs),
+                _ => Ok(lhs),
+            },
+
+            (Type::Constructor(Value(Eql(_))), Type::Constructor(Value(Eql(_)))) => {
                 if lhs == rhs {
-                    Ok(left.join(&right))
+                    Ok(lhs)
                 } else {
                     Err(TypeError::Conflict(format!(
                         "cannot unify different EQL types: {} and {}",
                         lhs, rhs
                     )))
                 }
             }
 
             // A constructor resolves with a type variable if either:
             // 1. the type variable does not already refer to a constructor (transitively), or
             // 2. it does refer to a constructor and the two constructors unify
-            (_, Type::Var(tvar)) => {
-                let unified = self.unify_with_type_var(left.clone(), *tvar)?;
-                Ok(TypeCell::join_all(&left, &[&right, &unified]))
-            }
+            (_, Type::Var(tvar)) => self.unify_with_type_var(lhs, *tvar),
 
             // A constructor resolves with a type variable if either:
             // 1. the type variable does not already refer to a constructor (transitively), or
             // 2. it does refer to a constructor and the two constructors unify
-            (Type::Var(tvar), _) => {
-                let unified = self.unify_with_type_var(right.clone(), *tvar)?;
-                Ok(TypeCell::join_all(&left, &[&right, &unified]))
-            }
+            (Type::Var(tvar), _) => self.unify_with_type_var(rhs, *tvar),
 
             // Any other combination of types is a type error.
             (lhs, rhs) => Err(TypeError::Conflict(format!(
                 "type {} cannot be unified with {}",
                 lhs, rhs
             ))),
         };
 
-        self.depth -= 1;
+        match unification {
+            Ok(ty) => {
+                event!(
+                    name: "UNIFY::OK",
+                    target: "eql-mapper::EVENT_UNIFY_OK",
+                    Level::TRACE,
+                    ty = %ty,
+                );
 
-        unification
+                Ok(ty)
+            }
+            Err(err) => {
+                event!(
+                    name: "UNIFY::ERR",
+                    target: "eql-mapper::EVENT_UNIFY_ERR",
+                    Level::TRACE,
+                    err = ?&err
+                );
+
+                Err(err)
+            }
+        }
     }
 
     /// Unifies a type with a type variable.
     ///
     /// Attempts to unify the type with whatever the type variable is pointing to.
     ///
     /// After successful unification `ty_rc` and `tvar_rc` will refer to the same allocation.
-    fn unify_with_type_var(&mut self, ty: TypeCell, tvar: TypeVar) -> Result<TypeCell, TypeError> {
-        let sub = {
-            let reg = &*self.registry.borrow();
-            reg.get_substitution(tvar)
+    fn unify_with_type_var(
+        &mut self,
+        ty: Arc<Type>,
+        tvar: TypeVar,
+    ) -> Result<Arc<Type>, TypeError> {
+        let sub_ty = {
+            let registry = &*self.registry.borrow();
+            registry.get_type(tvar)
         };
 
-        let ty = match sub {
-            Some(sub_ty) => self.unify(ty.clone(), sub_ty.clone())?,
-            None => ty.clone(),
+        let unified_ty: Arc<Type> = match sub_ty {
+            Some(sub_ty) => self.unify(ty, sub_ty)?,
+            None => ty,
         };
 
-        self.registry.borrow_mut().substitute(tvar, ty.clone());
+        self.substitute(tvar, unified_ty.clone());
 
-        Ok(ty)
+        Ok(unified_ty)
     }
 
     /// Unifies two projection types.
     fn unify_projections(
         &mut self,
-        left: TypeCell,
-        right: TypeCell,
-    ) -> Result<TypeCell, TypeError> {
-        match (&*left.as_type(), &*right.as_type()) {
+        lhs: Arc<Type>,
+        rhs: Arc<Type>,
+    ) -> Result<Arc<Type>, TypeError> {
+        match (&*lhs, &*rhs) {
             (
                 Type::Constructor(Constructor::Projection(lhs_projection)),
                 Type::Constructor(Constructor::Projection(rhs_projection)),
             ) => {
                 if lhs_projection.len() == rhs_projection.len() {
                     let mut cols = vec![];
                     for (lhs_col, rhs_col) in lhs_projection
                         .columns()
                         .iter()
                         .zip(rhs_projection.columns())
                     {
-                        cols.push(ProjectionColumn {
-                            ty: self.unify(lhs_col.ty.clone(), rhs_col.ty.clone())?,
-                            alias: lhs_col.alias.clone(),
-                        });
+                        let unified_ty = self.unify(lhs_col.ty.clone(), rhs_col.ty.clone())?;
+                        cols.push(ProjectionColumn::new(unified_ty, lhs_col.alias.clone()));
                     }
-                    let unified = TypeCell::new(Type::Constructor(Constructor::Projection(
-                        Projection::new(cols),
-                    )));
-                    Ok(left.join_all(&[&right, &unified]))
+
+                    let unified_ty =
+                        Type::Constructor(Constructor::Projection(Projection::new(cols)));
+
+                    Ok(unified_ty.into())
                 } else {
                     Err(TypeError::Conflict(format!(
                         "cannot unify projections {} and {} because they have different numbers of columns",
-                        *left.as_type(), *right.as_type()
+                        lhs, rhs
                     )))
                 }
             }
         }
     }
 
-    fn unify_value_type_with_type(
+    fn unify_value_type_with_one_col_projection(
         &mut self,
-        value: TypeCell,
-        ty: TypeCell,
-    ) -> Result<TypeCell, TypeError> {
-        let (a, b) = (value.as_type(), ty.as_type());
-
-        match (&*a, &*b) {
+        value_ty: Arc<Type>,
+        projection_ty: Arc<Type>,
+    ) -> Result<Arc<Type>, TypeError> {
+        match (&*value_ty, &*projection_ty) {
             (
-                Type::Constructor(Constructor::Value(Value::Eql(left))),
-                Type::Constructor(Constructor::Value(Value::Eql(right))),
-            ) if left == right => Ok(value.join(&ty)),
+                Type::Constructor(Constructor::Value(Value::Eql(lhs))),
+                Type::Constructor(Constructor::Value(Value::Eql(rhs))),
+            ) if lhs == rhs =>
Ok(value_ty.clone()),
 
             (
-                Type::Constructor(Constructor::Value(Value::Native(_))),
-                Type::Constructor(Constructor::Value(Value::Native(_))),
-            ) => Ok(value.join(&ty)),
+                Type::Constructor(Constructor::Value(Value::Native(lhs))),
+                Type::Constructor(Constructor::Value(Value::Native(rhs))),
+            ) => match (lhs, rhs) {
+                (NativeValue(Some(_)), NativeValue(Some(_))) => Ok(value_ty.clone()),
+                (NativeValue(Some(_)), NativeValue(None)) => Ok(value_ty.clone()),
+                (NativeValue(None), NativeValue(Some(_))) => Ok(projection_ty.clone()),
+                _ => Ok(value_ty.clone()),
+            },
 
             (
-                Type::Constructor(Constructor::Value(Value::Array(left))),
-                Type::Constructor(Constructor::Value(Value::Array(right))),
+                Type::Constructor(Constructor::Value(Value::Array(lhs))),
+                Type::Constructor(Constructor::Value(Value::Array(rhs))),
             ) => {
-                self.unify(left.clone(), right.clone())?;
-                Ok(value.join(&ty))
+                let unified_element_ty = self.unify(lhs.clone(), rhs.clone())?;
+                let unified_array_ty =
+                    Type::Constructor(Constructor::Value(Value::Array(unified_element_ty)));
+                Ok(unified_array_ty.into())
             }
 
             (Type::Constructor(Constructor::Value(Value::Eql(_))), Type::Var(tvar)) => {
-                let unified = self.unify_with_type_var(value.clone(), *tvar)?;
-                Ok(value.join_all(&[&ty, &unified]))
+                self.unify_with_type_var(value_ty.clone(), *tvar)
             }
 
             (Type::Var(tvar), Type::Constructor(Constructor::Value(Value::Eql(_)))) => {
-                let unified = self.unify_with_type_var(value.clone(), *tvar)?;
-                Ok(value.join_all(&[&ty, &unified]))
+                self.unify_with_type_var(projection_ty.clone(), *tvar)
             }
 
             _ => Err(TypeError::Conflict(format!(
                 "value type {} cannot be unified with single column projection of {}",
-                *value.as_type(),
-                *ty.as_type()
+                value_ty, projection_ty
             ))),
         }
     }
 }
 
+pub(crate) mod test_util {
+    use sqltk::{AsNodeKey, Break, Visitable, Visitor};
+    use sqltk_parser::ast::{
+        Delete, Expr, Function, FunctionArguments, Insert, Query, Select, SelectItem, SetExpr,
+        Statement, Value, Values,
+    };
+    use std::{any::type_name, convert::Infallible, fmt::Debug, ops::ControlFlow};
+    use tracing::{event, Level};
+
+    use crate::unifier::Unifier;
+
+    use std::fmt::Display;
+
+    impl<'ast> super::Unifier<'ast> {
+        pub(crate) fn dump_substitutions(&self) {
+            for (tvar, ty) in self.get_substitutions().iter() {
+                event!(
+                    target: "eql-mapper::DUMP_SUB",
+                    Level::TRACE,
+                    sub = format!("{} => {}", tvar, ty)
+                );
+            }
+        }
+
+        /// Dumps the type information for a specific AST node to STDERR.
+        ///
+        /// Useful when debugging tests.
+        pub(crate) fn dump_node<N: AsNodeKey + Display>(&self, node: &'ast N) {
+            let root_ty = self.get_node_type(node).clone();
+            let found_ty = root_ty.clone().follow_tvars(self);
+            let ast_ty = type_name::<N>();
+
+            event!(
+                target: "eql-mapper::DUMP_NODE",
+                Level::TRACE,
+                ast_ty = ast_ty,
+                node = %node,
+                root_ty = %root_ty,
+                found_ty = %found_ty
+            );
+        }
+
+        /// Dumps the type information for all nodes visited so far to STDERR.
+        ///
+        /// Useful when debugging tests.
+        pub(crate) fn dump_all_nodes<N: Visitable>(&self, root_node: &'ast N) {
+            struct FindNodeFromKeyVisitor<'a, 'ast>(&'a Unifier<'ast>);
+
+            impl<'ast> Visitor<'ast> for FindNodeFromKeyVisitor<'_, 'ast> {
+                type Error = Infallible;
+
+                fn enter<N: Visitable>(
+                    &mut self,
+                    node: &'ast N,
+                ) -> ControlFlow<Break<Self::Error>> {
+                    if let Some(node) = node.downcast_ref::<Statement>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::<Query>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::<Insert>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::<Delete>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::<Expr>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::<SetExpr>() {
+                        self.0.dump_node(node);
+                    }
+
+                    if let Some(node) = node.downcast_ref::
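The dump helpers above emit `tracing` events (targets like `eql-mapper::DUMP_NODE` and `eql-mapper::DUMP_SUB`) rather than writing to STDERR directly, so a test only sees them when a TRACE-level subscriber is installed. A minimal setup using the workspace's existing `tracing-subscriber` dependency:

```rust
use tracing::Level;

fn init_test_tracing() {
    // with_test_writer() routes output through the test harness so it is
    // shown for failing tests (or with --nocapture).
    tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .with_test_writer()
        .init();
}

fn main() {
    init_test_tracing();
    tracing::trace!(target: "eql-mapper::DUMP_SUB", sub = "example");
}
```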