Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/cipherstash-proxy/src/encrypt/schema/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,13 @@ pub async fn load_schema(config: &DatabaseConfig) -> Result<Schema, Error> {
};

for table in tables {
let table_schema: String = table.get("table_schema");
let table_name: String = table.get("table_name");
let primary_keys: Vec<Option<String>> = table.get("primary_keys");
let columns: Vec<String> = table.get("columns");
let column_type_names: Vec<Option<String>> = table.get("column_type_names");

let mut table = Table::new(Ident::new(&table_name));
let mut table = Table::new(Ident::new(&table_name), Ident::new(&table_schema));

columns.iter().zip(column_type_names).for_each(|(col, column_type_name)| {
let is_primary_key = primary_keys.contains(&Some(col.to_string()));
Expand Down
12 changes: 9 additions & 3 deletions packages/eql-mapper/src/importer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
},
model::{SchemaError, TableResolver},
unifier::{Projection, ProjectionColumns},
Relation, ScopeError, ScopeTracker,
Relation, ScopeError, ScopeTracker, Table,
};
use sqltk::parser::ast::{Cte, Ident, Insert, TableAlias, TableFactor};
use sqltk::{Break, Visitable, Visitor};
Expand Down Expand Up @@ -42,9 +42,11 @@ impl<'ast> Importer<'ast> {
..
} = insert;

let (table_name, source_schema) = Table::get_table_name_and_source_schema(table_name);

let table = self
.table_resolver
.resolve_table(table_name.0.last().unwrap())?;
.resolve_table(&table_name, &source_schema)?;

let cols = ProjectionColumns::new_from_schema_table(table.clone());

Expand Down Expand Up @@ -103,8 +105,12 @@ impl<'ast> Importer<'ast> {

let mut scope_tracker = self.scope_tracker.borrow_mut();

let (table_name, source_schema) = Table::get_table_name_and_source_schema(name);

if scope_tracker.resolve_relation(name).is_err() {
let table = self.table_resolver.resolve_table(name.0.last().unwrap())?;
let table = self
.table_resolver
.resolve_table(&table_name, &source_schema)?;

let cols = ProjectionColumns::new_from_schema_table(table.clone());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use crate::{
InferType,
},
unifier::{EqlValue, NativeValue, Value},
ColumnKind, TableColumn, TypeInferencer,
ColumnKind, Table, TableColumn, TypeInferencer,
};
use eql_mapper_macros::trace_infer;
use sqltk::parser::ast::{Ident, Insert};
use sqltk::parser::ast::Insert;

#[trace_infer]
impl<'ast> InferType<'ast, Insert> for TypeInferencer<'ast> {
Expand All @@ -27,15 +27,19 @@ impl<'ast> InferType<'ast, Insert> for TypeInferencer<'ast> {
return Err(TypeError::UnsupportedSqlFeature("INSERT with ALIAS".into()));
}

let table_name: &Ident = table_name.0.last().unwrap();
let (table_name, source_schema) = Table::get_table_name_and_source_schema(table_name);

let table_columns = if columns.is_empty() {
// When no columns are specified, the source must unify with a projection of ALL table columns.
self.table_resolver.resolve_table_columns(table_name)?
self.table_resolver
.resolve_table_columns(&table_name, &source_schema)?
} else {
columns
.iter()
.map(|c| self.table_resolver.resolve_table_column(table_name, c))
.map(|c| {
self.table_resolver
.resolve_table_column(&table_name, &source_schema, c)
})
.collect::<Result<Vec<_>, _>>()?
};

Expand Down
52 changes: 41 additions & 11 deletions packages/eql-mapper/src/model/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::sql_ident::*;
use crate::iterator_ext::IteratorExt;
use core::fmt::Debug;
use derive_more::Display;
use sqltk::parser::ast::Ident;
use sqltk::parser::ast::{Ident, ObjectName};
use std::sync::Arc;
use thiserror::Error;

Expand All @@ -23,9 +23,16 @@ pub struct Schema {
/// A table (or view).
///
/// It has a name and some columns
///
/// The source_schema is the schema that the table "physically" belongs to
/// A user-visible schema can provide permissions across many source schemas.
/// ie, the "information_schema" and "pg_catalog" schemas are both included in the cidde loaded schema
///
#[derive(Debug, Clone, PartialEq, Eq, Display, Hash)]
#[display("Table<{}>", name)]
pub struct Table {
// The schema this table belongs to (e.g. "public", "information_schema", "pg_catalog")
pub source_schema: Ident,
pub name: Ident,
pub columns: Vec<Arc<Column>>,
// Stores indices into the columns Vec.
Expand Down Expand Up @@ -102,19 +109,27 @@ impl Schema {

/// Resolves a table by `Ident`, which takes into account the SQL rules
/// of quoted and new identifier matching.
pub fn resolve_table(&self, name: &Ident) -> Result<Arc<Table>, SchemaError> {
pub fn resolve_table(
&self,
table_name: &Ident,
source_schema: &Ident,
) -> Result<Arc<Table>, SchemaError> {
let mut haystack = self.tables.iter();
haystack
.find_unique(&|table| SqlIdent::from(&table.name) == SqlIdent::from(name))
.find_unique(&|table| {
SqlIdent::from(&table.name) == SqlIdent::from(table_name)
&& SqlIdent::from(&table.source_schema) == SqlIdent::from(source_schema)
})
.cloned()
.map_err(|_| SchemaError::TableNotFound(name.to_string()))
.map_err(|_| SchemaError::TableNotFound(table_name.to_string()))
}

pub fn resolve_table_columns(
&self,
table_name: &Ident,
source_schema: &Ident,
) -> Result<Vec<SchemaTableColumn>, SchemaError> {
let table = self.resolve_table(table_name)?;
let table = self.resolve_table(table_name, source_schema)?;
Ok(table
.columns
.iter()
Expand All @@ -129,12 +144,14 @@ impl Schema {
pub fn resolve_table_column(
&self,
table_name: &Ident,
source_schema: &Ident,
column_name: &Ident,
) -> Result<SchemaTableColumn, SchemaError> {
let mut haystack = self.tables.iter();
match haystack
.find_unique(&|table| SqlIdent::from(&table.name) == SqlIdent::from(table_name))
{
match haystack.find_unique(&|table| {
SqlIdent::from(&table.name) == SqlIdent::from(table_name)
&& SqlIdent::from(&table.source_schema) == SqlIdent::from(source_schema)
}) {
Ok(table) => match table.get_column(column_name) {
Ok(column) => Ok(SchemaTableColumn {
table: table.name.clone(),
Expand All @@ -153,9 +170,10 @@ impl Schema {

impl Table {
/// Create a new named table with no columns.
pub fn new(name: Ident) -> Self {
pub fn new(name: Ident, source_schema: Ident) -> Self {
Self {
name,
source_schema,
primary_key: Vec::with_capacity(1),
columns: Vec::with_capacity(16),
}
Expand Down Expand Up @@ -192,6 +210,18 @@ impl Table {
.cloned()
.collect()
}

// ObjectName is a list of identifiers, the last entry is always the table name
pub fn get_table_name_and_source_schema(name: &ObjectName) -> (Ident, Ident) {
let idents = &name.0;
let table_name = idents.last().unwrap().clone();
let source_schema = if idents.len() == 2 {
idents.first().unwrap().clone()
} else {
Ident::new("public")
};
(table_name, source_schema)
}
}

/// A DSL to create a [`Schema`] for testing purposes.
Expand Down Expand Up @@ -227,8 +257,8 @@ macro_rules! schema {
(@add_table $schema:ident $table_name:ident $table:ident { $($columns:tt)* }) => {
$schema.add_table(
{

let mut $table = $crate::model::Table::new(::sqltk::parser::ast::Ident::new(stringify!($table_name)));
let source_schema = ::sqltk::parser::ast::Ident::new("public");
let mut $table = $crate::model::Table::new(::sqltk::parser::ast::Ident::new(stringify!($table_name)), source_schema);
schema!(@add_columns $table $($columns)*);
$table
}
Expand Down
Loading
Loading