Skip to content

Commit 678c746

Browse files
committed
feat: add more rewrite rules
1 parent 08bdc60 commit 678c746

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
55
use crate::sql::{
6-
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
7-
ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation, SqlStatementRewriteRule,
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, PrependUnqualifiedTableName,
7+
RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, RewriteArrayAnyOperation,
8+
SqlStatementRewriteRule,
89
};
910
use async_trait::async_trait;
1011
use datafusion::arrow::datatypes::DataType;
@@ -82,6 +83,7 @@ impl DfSessionService {
8283
Arc::new(ResolveUnqualifiedIdentifer),
8384
Arc::new(RemoveUnsupportedTypes::new()),
8485
Arc::new(RewriteArrayAnyOperation),
86+
Arc::new(PrependUnqualifiedTableName::new()),
8587
];
8688
let parser = Arc::new(Parser {
8789
session_context: session_context.clone(),
@@ -297,8 +299,8 @@ impl DfSessionService {
297299
Ok(Some(Response::Query(resp)))
298300
}
299301
"show search_path" => {
300-
let default_catalog = "datafusion";
301-
let resp = Self::mock_show_response("search_path", default_catalog)?;
302+
let default_schema = "public";
303+
let resp = Self::mock_show_response("search_path", default_schema)?;
302304
Ok(Some(Response::Query(resp)))
303305
}
304306
_ => Err(PgWireError::UserError(Box::new(

datafusion-postgres/src/sql.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use datafusion::sql::sqlparser::ast::FunctionArgumentList;
1111
use datafusion::sql::sqlparser::ast::FunctionArguments;
1212
use datafusion::sql::sqlparser::ast::Ident;
1313
use datafusion::sql::sqlparser::ast::ObjectName;
14+
use datafusion::sql::sqlparser::ast::ObjectNamePart;
1415
use datafusion::sql::sqlparser::ast::OrderByKind;
1516
use datafusion::sql::sqlparser::ast::Query;
1617
use datafusion::sql::sqlparser::ast::Select;
@@ -401,6 +402,63 @@ impl SqlStatementRewriteRule for RewriteArrayAnyOperation {
401402
}
402403
}
403404

405+
/// Prepend qualifier to table_name
406+
///
407+
/// Postgres has pg_catalog in search_path by default so it allow access to
408+
/// `pg_namespace` without `pg_catalog.` qualifier
409+
#[derive(Debug)]
410+
pub struct PrependUnqualifiedTableName {
411+
table_names: HashSet<String>,
412+
}
413+
414+
impl PrependUnqualifiedTableName {
415+
pub fn new() -> Self {
416+
let mut table_names = HashSet::new();
417+
418+
table_names.insert("pg_namespace".to_owned());
419+
420+
Self { table_names }
421+
}
422+
}
423+
424+
struct PrependUnqualifiedTableNameVisitor<'a> {
425+
table_names: &'a HashSet<String>,
426+
}
427+
428+
impl<'a> VisitorMut for PrependUnqualifiedTableNameVisitor<'a> {
429+
type Break = ();
430+
431+
fn pre_visit_table_factor(
432+
&mut self,
433+
table_factor: &mut TableFactor,
434+
) -> ControlFlow<Self::Break> {
435+
if let TableFactor::Table { name, .. } = table_factor {
436+
if name.0.len() == 1 {
437+
let ObjectNamePart::Identifier(ident) = &name.0[0];
438+
if self.table_names.contains(&ident.to_string()) {
439+
*name = ObjectName(vec![
440+
ObjectNamePart::Identifier(Ident::new("pg_catalog")),
441+
name.0[0].clone(),
442+
]);
443+
}
444+
}
445+
}
446+
447+
ControlFlow::Continue(())
448+
}
449+
}
450+
451+
impl SqlStatementRewriteRule for PrependUnqualifiedTableName {
452+
fn rewrite(&self, mut s: Statement) -> Statement {
453+
let mut visitor = PrependUnqualifiedTableNameVisitor {
454+
table_names: &self.table_names,
455+
};
456+
457+
let _ = s.visit(&mut visitor);
458+
s
459+
}
460+
}
461+
404462
#[cfg(test)]
405463
mod tests {
406464
use super::*;
@@ -524,4 +582,28 @@ mod tests {
524582
"SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
525583
);
526584
}
585+
586+
#[test]
587+
fn test_prepend_unqualified_table_name() {
588+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
589+
vec![Arc::new(PrependUnqualifiedTableName::new())];
590+
591+
assert_rewrite!(
592+
&rules,
593+
"SELECT * FROM pg_catalog.pg_namespace",
594+
"SELECT * FROM pg_catalog.pg_namespace"
595+
);
596+
597+
assert_rewrite!(
598+
&rules,
599+
"SELECT * FROM pg_namespace",
600+
"SELECT * FROM pg_catalog.pg_namespace"
601+
);
602+
603+
assert_rewrite!(
604+
&rules,
605+
"SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
606+
"SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
607+
);
608+
}
527609
}

datafusion-postgres/tests/dbeaver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const DBEAVER_QUERIES: &[&str] = &[
1313
"SELECT typinput='pg_catalog.array_in'::regproc as is_array, typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN (select ns.oid as nspoid, ns.nspname, r.r from pg_namespace as ns join ( select s.r, (current_schemas(false))[s.r] as nspname from generate_series(1, array_upper(current_schemas(false), 1)) as s(r) ) as r using ( nspname ) ) as sp ON sp.nspoid = typnamespace WHERE pg_type.oid = 1034 ORDER BY sp.r, pg_type.oid DESC",
1414
"SHOW search_path",
1515
"SELECT db.oid,db.* FROM pg_catalog.pg_database db WHERE datname='postgres'",
16-
"SELECT * FROM pg_catalog.pg_settinngs where name='standard_conforming_strings'",
16+
"SELECT * FROM pg_catalog.pg_settings where name='standard_conforming_strings'",
1717
"SELECT string_agg(word, ',' ) from pg_catalog.pg_get_keywords() where word <> ALL ('{a,abs,absolute,action,ada,add,admin,after,all,allocate,alter,aIways,and,any,are,array,as,asc,asenstitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,before,begin,bernoulli,between,bigint,binary,blob,boolean,both,breaadth,by,c,call,called,cardinaliity,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,char_length,character,character_length,character_set_catalog,character_set_name,character_set_schema,characteristics,characters,check,checkeed,class_origin,clob,close,coalesce,coboI,code_units,collate,collation,collaition_catalog,collaition_name,collaition_schema,collect,colum,column_name,command_function,command_function_code,commit,committed,condiition,condiition_number,connect,connection_name,constraint,constraint_catalog,constraint_name,constraint_schema,constraints,constructors,contains,continue,convert,corr,correspondiing,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_collation,current_date,current_default_transfom_group,current_path,current_role,current_time,current_timestamp,current_transfom_group_for_type,current_user,cursor,cursor_name,cycle,data,date,datetime_interval_code,datetime_interval_precision,day,deallocate,dec,decimaI,declare,default,defaults,not,null,nullable,nullif,nulls,number,numeric,object,octeet_length,octets,of,old,on,only,open,option,options,or,order,ordering,ordinaliity,others,out,outer,output,over,overlaps,overlay,overriding,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_speciific_catalog,parameter_speciific_name,parameter_speciific_schema,partiaI,partitioon,pascal,path,percent_rank,percentile_cont,percentile_disc,placing,pli,position,power,preceding,precision,prepare,preseerv,primary,prior,privileges,procedure,public,range,rank,read,reads,real,recursivve,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_sy y,relative,release,repeatable,restart,result,retun,returned_cardinality,returned_length,returned_octeet_length,returned_sqlstate,returns,revoe,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,row_count,row_number,rows,savepoint,scale,schema,schema_name,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,self,sensitive,sequence,seriializeable,server_name,session,session_user,set,sets,similar,simple,size,smalIint,some,source,space,specifiic,speciific_name,speciifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,state,statement,static,stddev_pop,stddev_samp,structure,style,subclass_origin,submultiset,substring,sum,symmetric,system,system_user,table,table_name,tablesample,temporary,then,ties,time,timesamp,timezone_hour,timezone_minute,to,top_level_count,trailing,transaction,transaction_active,transactions_committed,transactions_rolled_back,transfor,transforms,translate,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,true,type,unbounde,undefined,uncommitted,under,union,unique,unknown,unnaamed,unnest,update,upper,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,value,values,var_pop,var_samp,varchar,varying,view,when,whenever,where,width_bucket,window,with,within,without,work,write,year,zone",
1818
"SELECT version()",
1919
"SELECT * FROM pg_catalog.pg_enum WHERE 1<>1 LIMIT 1",

0 commit comments

Comments
 (0)