Skip to content

Commit 743665b

Browse files
committed
feat: add source schema to table resolution
1 parent 2fce71d commit 743665b

File tree

6 files changed

+175
-71
lines changed

6 files changed

+175
-71
lines changed

packages/cipherstash-proxy/src/encrypt/schema/manager.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,13 @@ pub async fn load_schema(config: &DatabaseConfig) -> Result<Schema, Error> {
129129
};
130130

131131
for table in tables {
132+
let table_schema: String = table.get("table_schema");
132133
let table_name: String = table.get("table_name");
133134
let primary_keys: Vec<Option<String>> = table.get("primary_keys");
134135
let columns: Vec<String> = table.get("columns");
135136
let column_type_names: Vec<Option<String>> = table.get("column_type_names");
136137

137-
let mut table = Table::new(Ident::new(&table_name));
138+
let mut table = Table::new(Ident::new(&table_name), Ident::new(&table_schema));
138139

139140
columns.iter().zip(column_type_names).for_each(|(col, column_type_name)| {
140141
let is_primary_key = primary_keys.contains(&Some(col.to_string()));

packages/eql-mapper/src/importer.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
},
66
model::{SchemaError, TableResolver},
77
unifier::{Projection, ProjectionColumns},
8-
Relation, ScopeError, ScopeTracker,
8+
Relation, ScopeError, ScopeTracker, Table,
99
};
1010
use sqltk::parser::ast::{Cte, Ident, Insert, TableAlias, TableFactor};
1111
use sqltk::{Break, Visitable, Visitor};
@@ -42,9 +42,11 @@ impl<'ast> Importer<'ast> {
4242
..
4343
} = insert;
4444

45+
let (table_name, source_schema) = Table::get_table_name_and_source_schema(table_name);
46+
4547
let table = self
4648
.table_resolver
47-
.resolve_table(table_name.0.last().unwrap())?;
49+
.resolve_table(&table_name, &source_schema)?;
4850

4951
let cols = ProjectionColumns::new_from_schema_table(table.clone());
5052

@@ -103,8 +105,12 @@ impl<'ast> Importer<'ast> {
103105

104106
let mut scope_tracker = self.scope_tracker.borrow_mut();
105107

108+
let (table_name, source_schema) = Table::get_table_name_and_source_schema(name);
109+
106110
if scope_tracker.resolve_relation(name).is_err() {
107-
let table = self.table_resolver.resolve_table(name.0.last().unwrap())?;
111+
let table = self
112+
.table_resolver
113+
.resolve_table(&table_name, &source_schema)?;
108114

109115
let cols = ProjectionColumns::new_from_schema_table(table.clone());
110116

packages/eql-mapper/src/inference/infer_type_impls/insert_statement.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ use crate::{
77
InferType,
88
},
99
unifier::{EqlValue, NativeValue, Value},
10-
ColumnKind, TableColumn, TypeInferencer,
10+
ColumnKind, Table, TableColumn, TypeInferencer,
1111
};
1212
use eql_mapper_macros::trace_infer;
13-
use sqltk::parser::ast::{Ident, Insert};
13+
use sqltk::parser::ast::Insert;
1414

1515
#[trace_infer]
1616
impl<'ast> InferType<'ast, Insert> for TypeInferencer<'ast> {
@@ -27,15 +27,19 @@ impl<'ast> InferType<'ast, Insert> for TypeInferencer<'ast> {
2727
return Err(TypeError::UnsupportedSqlFeature("INSERT with ALIAS".into()));
2828
}
2929

30-
let table_name: &Ident = table_name.0.last().unwrap();
30+
let (table_name, source_schema) = Table::get_table_name_and_source_schema(table_name);
3131

3232
let table_columns = if columns.is_empty() {
3333
// When no columns are specified, the source must unify with a projection of ALL table columns.
34-
self.table_resolver.resolve_table_columns(table_name)?
34+
self.table_resolver
35+
.resolve_table_columns(&table_name, &source_schema)?
3536
} else {
3637
columns
3738
.iter()
38-
.map(|c| self.table_resolver.resolve_table_column(table_name, c))
39+
.map(|c| {
40+
self.table_resolver
41+
.resolve_table_column(&table_name, &source_schema, c)
42+
})
3943
.collect::<Result<Vec<_>, _>>()?
4044
};
4145

packages/eql-mapper/src/model/schema.rs

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use super::sql_ident::*;
66
use crate::iterator_ext::IteratorExt;
77
use core::fmt::Debug;
88
use derive_more::Display;
9-
use sqltk::parser::ast::Ident;
9+
use sqltk::parser::ast::{Ident, ObjectName};
1010
use std::sync::Arc;
1111
use thiserror::Error;
1212

@@ -23,9 +23,16 @@ pub struct Schema {
2323
/// A table (or view).
2424
///
2525
/// It has a name and some columns
26+
///
27+
/// The source_schema is the schema that the table "physically" belongs to
28+
/// A user-visible schema can provide permissions across many source schemas.
29+
/// ie, the "information_schema" and "pg_catalog" schemas are both included in the cidde loaded schema
30+
///
2631
#[derive(Debug, Clone, PartialEq, Eq, Display, Hash)]
2732
#[display("Table<{}>", name)]
2833
pub struct Table {
34+
// The schema this table belongs to (e.g. "public", "information_schema", "pg_catalog")
35+
pub source_schema: Ident,
2936
pub name: Ident,
3037
pub columns: Vec<Arc<Column>>,
3138
// Stores indices into the columns Vec.
@@ -102,19 +109,27 @@ impl Schema {
102109

103110
/// Resolves a table by `Ident`, which takes into account the SQL rules
104111
/// of quoted and new identifier matching.
105-
pub fn resolve_table(&self, name: &Ident) -> Result<Arc<Table>, SchemaError> {
112+
pub fn resolve_table(
113+
&self,
114+
table_name: &Ident,
115+
source_schema: &Ident,
116+
) -> Result<Arc<Table>, SchemaError> {
106117
let mut haystack = self.tables.iter();
107118
haystack
108-
.find_unique(&|table| SqlIdent::from(&table.name) == SqlIdent::from(name))
119+
.find_unique(&|table| {
120+
SqlIdent::from(&table.name) == SqlIdent::from(table_name)
121+
&& SqlIdent::from(&table.source_schema) == SqlIdent::from(source_schema)
122+
})
109123
.cloned()
110-
.map_err(|_| SchemaError::TableNotFound(name.to_string()))
124+
.map_err(|_| SchemaError::TableNotFound(table_name.to_string()))
111125
}
112126

113127
pub fn resolve_table_columns(
114128
&self,
115129
table_name: &Ident,
130+
source_schema: &Ident,
116131
) -> Result<Vec<SchemaTableColumn>, SchemaError> {
117-
let table = self.resolve_table(table_name)?;
132+
let table = self.resolve_table(table_name, source_schema)?;
118133
Ok(table
119134
.columns
120135
.iter()
@@ -129,12 +144,14 @@ impl Schema {
129144
pub fn resolve_table_column(
130145
&self,
131146
table_name: &Ident,
147+
source_schema: &Ident,
132148
column_name: &Ident,
133149
) -> Result<SchemaTableColumn, SchemaError> {
134150
let mut haystack = self.tables.iter();
135-
match haystack
136-
.find_unique(&|table| SqlIdent::from(&table.name) == SqlIdent::from(table_name))
137-
{
151+
match haystack.find_unique(&|table| {
152+
SqlIdent::from(&table.name) == SqlIdent::from(table_name)
153+
&& SqlIdent::from(&table.source_schema) == SqlIdent::from(source_schema)
154+
}) {
138155
Ok(table) => match table.get_column(column_name) {
139156
Ok(column) => Ok(SchemaTableColumn {
140157
table: table.name.clone(),
@@ -153,9 +170,10 @@ impl Schema {
153170

154171
impl Table {
155172
/// Create a new named table with no columns.
156-
pub fn new(name: Ident) -> Self {
173+
pub fn new(name: Ident, source_schema: Ident) -> Self {
157174
Self {
158175
name,
176+
source_schema,
159177
primary_key: Vec::with_capacity(1),
160178
columns: Vec::with_capacity(16),
161179
}
@@ -192,6 +210,18 @@ impl Table {
192210
.cloned()
193211
.collect()
194212
}
213+
214+
// ObjectName is a list of identifiers, the last entry is always the table name
215+
pub fn get_table_name_and_source_schema(name: &ObjectName) -> (Ident, Ident) {
216+
let idents = &name.0;
217+
let table_name = idents.last().unwrap().clone();
218+
let source_schema = if idents.len() == 2 {
219+
idents.first().unwrap().clone()
220+
} else {
221+
Ident::new("public")
222+
};
223+
(table_name, source_schema)
224+
}
195225
}
196226

197227
/// A DSL to create a [`Schema`] for testing purposes.
@@ -227,8 +257,8 @@ macro_rules! schema {
227257
(@add_table $schema:ident $table_name:ident $table:ident { $($columns:tt)* }) => {
228258
$schema.add_table(
229259
{
230-
231-
let mut $table = $crate::model::Table::new(::sqltk::parser::ast::Ident::new(stringify!($table_name)));
260+
let source_schema = ::sqltk::parser::ast::Ident::new("public");
261+
let mut $table = $crate::model::Table::new(::sqltk::parser::ast::Ident::new(stringify!($table_name)), source_schema);
232262
schema!(@add_columns $table $($columns)*);
233263
$table
234264
}

0 commit comments

Comments
 (0)