Skip to content

Commit c1779dc

Browse files
authored
feat: implement dbeaver startup queries one-by-one (#140)
* feat: implement dbeaver startup queries one-by-one * fix: correct query feed into handler * chore: correct test case * test: update some test case * feat: add rewrite rule to transform any operation * feat: add more rewrite rules * feat: add pg_settings view * fix: crash on encoding utf8view list * feat: add array rewrite * feat: implement more about pg_get_keywords * feat: add support for ALL op * feat: add remaining udfs for dbeaver startup queries * chore: update readme
1 parent c01cf50 commit c1779dc

File tree

13 files changed

+856
-28
lines changed

13 files changed

+856
-28
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ project.
2222
- Permission control
2323
- Built-in `pg_catalog` tables
2424
- Built-in postgres functions for common meta queries
25+
- [x] DBeaver compatibility
2526
- `datafusion-postgres-cli`: A cli tool starts a postgres compatible server for
2627
datafusion supported file formats, just like python's `SimpleHTTPServer`.
2728
- `arrow-pg`: A data type mapping, encoding/decoding library for arrow and

arrow-pg/src/list_encoder.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{str::FromStr, sync::Arc};
22

3+
use arrow::array::{BinaryViewArray, StringViewArray};
34
#[cfg(not(feature = "datafusion"))]
45
use arrow::{
56
array::{
@@ -150,6 +151,15 @@ pub(crate) fn encode_list(
150151
.collect();
151152
encode_field(&value, type_, format)
152153
}
154+
DataType::Utf8View => {
155+
let value: Vec<Option<&str>> = arr
156+
.as_any()
157+
.downcast_ref::<StringViewArray>()
158+
.unwrap()
159+
.iter()
160+
.collect();
161+
encode_field(&value, type_, format)
162+
}
153163
DataType::Binary => {
154164
let value: Vec<Option<_>> = arr
155165
.as_any()
@@ -168,6 +178,15 @@ pub(crate) fn encode_list(
168178
.collect();
169179
encode_field(&value, type_, format)
170180
}
181+
DataType::BinaryView => {
182+
let value: Vec<Option<_>> = arr
183+
.as_any()
184+
.downcast_ref::<BinaryViewArray>()
185+
.unwrap()
186+
.iter()
187+
.collect();
188+
encode_field(&value, type_, format)
189+
}
171190

172191
DataType::Date32 => {
173192
let value: Vec<Option<_>> = arr

datafusion-postgres/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ tokio = { version = "1.47", features = ["sync", "net"] }
2828
tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] }
2929
rustls-pemfile = "2.0"
3030
rustls-pki-types = "1.0"
31+
32+
[dev-dependencies]
33+
env_logger = "0.11"

datafusion-postgres/src/handlers.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ use std::sync::Arc;
33

44
use crate::auth::{AuthManager, Permission, ResourceType};
55
use crate::sql::{
6-
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
7-
ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
6+
parse, rewrite, AliasDuplicatedProjectionRewrite, FixArrayLiteral, PrependUnqualifiedTableName,
7+
RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
8+
RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
89
};
910
use async_trait::async_trait;
1011
use datafusion::arrow::datatypes::DataType;
1112
use datafusion::logical_expr::LogicalPlan;
1213
use datafusion::prelude::*;
1314
use datafusion::sql::parser::Statement;
15+
use log::warn;
1416
use pgwire::api::auth::noop::NoopStartupHandler;
1517
use pgwire::api::auth::StartupHandler;
1618
use pgwire::api::portal::{Format, Portal};
@@ -80,6 +82,10 @@ impl DfSessionService {
8082
Arc::new(AliasDuplicatedProjectionRewrite),
8183
Arc::new(ResolveUnqualifiedIdentifer),
8284
Arc::new(RemoveUnsupportedTypes::new()),
85+
Arc::new(RewriteArrayAnyAllOperation),
86+
Arc::new(PrependUnqualifiedTableName::new()),
87+
Arc::new(FixArrayLiteral),
88+
Arc::new(RemoveTableFunctionQualifier),
8389
];
8490
let parser = Arc::new(Parser {
8591
session_context: session_context.clone(),
@@ -211,14 +217,12 @@ impl DfSessionService {
211217
}
212218
} else {
213219
// pass SET query to datafusion
214-
let df = self
215-
.session_context
216-
.sql(query_lower)
217-
.await
218-
.map_err(|err| PgWireError::ApiError(Box::new(err)))?;
219-
220-
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
221-
Ok(Some(Response::Query(resp)))
220+
if let Err(e) = self.session_context.sql(query_lower).await {
221+
warn!("SET statement {query_lower} is not supported by datafusion, error {e}, statement ignored");
222+
}
223+
224+
// Always return SET success
225+
Ok(Some(Response::Execution(Tag::new("SET"))))
222226
}
223227
} else {
224228
Ok(None)
@@ -297,8 +301,8 @@ impl DfSessionService {
297301
Ok(Some(Response::Query(resp)))
298302
}
299303
"show search_path" => {
300-
let default_catalog = "datafusion";
301-
let resp = Self::mock_show_response("search_path", default_catalog)?;
304+
let default_schema = "public";
305+
let resp = Self::mock_show_response("search_path", default_schema)?;
302306
Ok(Some(Response::Query(resp)))
303307
}
304308
_ => Err(PgWireError::UserError(Box::new(
@@ -331,7 +335,8 @@ impl SimpleQueryHandler for DfSessionService {
331335
statement = rewrite(statement, &self.sql_rewrite_rules);
332336

333337
// TODO: improve statement check by using statement directly
334-
let query_lower = statement.to_string().to_lowercase().trim().to_string();
338+
let query = statement.to_string();
339+
let query_lower = query.to_lowercase().trim().to_string();
335340

336341
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
337342
if !query_lower.starts_with("set")
@@ -343,7 +348,7 @@ impl SimpleQueryHandler for DfSessionService {
343348
&& !query_lower.starts_with("abort")
344349
&& !query_lower.starts_with("show")
345350
{
346-
self.check_query_permission(client, query).await?;
351+
self.check_query_permission(client, &query).await?;
347352
}
348353

349354
if let Some(resp) = self.try_respond_set_statements(&query_lower).await? {
@@ -373,7 +378,7 @@ impl SimpleQueryHandler for DfSessionService {
373378
)));
374379
}
375380

376-
let df_result = self.session_context.sql(query).await;
381+
let df_result = self.session_context.sql(&query).await;
377382

378383
// Handle query execution errors and transaction state
379384
let df = match df_result {

datafusion-postgres/src/pg_catalog.rs

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@ use datafusion::arrow::array::{
1010
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
1111
use datafusion::arrow::ipc::reader::FileReader;
1212
use datafusion::catalog::streaming::StreamingTable;
13-
use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider};
13+
use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl};
1414
use datafusion::common::utils::SingleRowListArrayBuilder;
1515
use datafusion::datasource::{TableProvider, ViewTable};
1616
use datafusion::error::{DataFusionError, Result};
1717
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
1818
use datafusion::physical_plan::streaming::PartitionStream;
19-
use datafusion::prelude::{create_udf, SessionContext};
19+
use datafusion::prelude::{create_udf, Expr, SessionContext};
2020
use postgres_types::Oid;
2121
use tokio::sync::RwLock;
2222

2323
mod pg_attribute;
2424
mod pg_class;
2525
mod pg_database;
2626
mod pg_namespace;
27+
mod pg_settings;
2728

2829
const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate";
2930
const PG_CATALOG_TABLE_PG_AM: &str = "pg_am";
@@ -86,6 +87,7 @@ const PG_CATALOG_TABLE_PG_SUBSCRIPTION_REL: &str = "pg_subscription_rel";
8687
const PG_CATALOG_TABLE_PG_TABLESPACE: &str = "pg_tablespace";
8788
const PG_CATALOG_TABLE_PG_TRIGGER: &str = "pg_trigger";
8889
const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping";
90+
const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings";
8991

9092
/// Determine PostgreSQL table type (relkind) from DataFusion TableProvider
9193
fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
@@ -180,6 +182,7 @@ pub const PG_CATALOG_TABLES: &[&str] = &[
180182
PG_CATALOG_TABLE_PG_TABLESPACE,
181183
PG_CATALOG_TABLE_PG_TRIGGER,
182184
PG_CATALOG_TABLE_PG_USER_MAPPING,
185+
PG_CATALOG_VIEW_PG_SETTINGS,
183186
];
184187

185188
#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
@@ -196,7 +199,7 @@ pub struct PgCatalogSchemaProvider {
196199
catalog_list: Arc<dyn CatalogProviderList>,
197200
oid_counter: Arc<AtomicU32>,
198201
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
199-
static_tables: PgCatalogStaticTables,
202+
static_tables: Arc<PgCatalogStaticTables>,
200203
}
201204

202205
#[async_trait]
@@ -345,6 +348,10 @@ impl SchemaProvider for PgCatalogSchemaProvider {
345348
StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
346349
)))
347350
}
351+
PG_CATALOG_VIEW_PG_SETTINGS => {
352+
let table = pg_settings::PgSettingsView::try_new()?;
353+
Ok(Some(Arc::new(table.try_into_memtable()?)))
354+
}
348355

349356
_ => Ok(None),
350357
}
@@ -356,12 +363,15 @@ impl SchemaProvider for PgCatalogSchemaProvider {
356363
}
357364

358365
impl PgCatalogSchemaProvider {
359-
pub fn try_new(catalog_list: Arc<dyn CatalogProviderList>) -> Result<PgCatalogSchemaProvider> {
366+
pub fn try_new(
367+
catalog_list: Arc<dyn CatalogProviderList>,
368+
static_tables: Arc<PgCatalogStaticTables>,
369+
) -> Result<PgCatalogSchemaProvider> {
360370
Ok(Self {
361371
catalog_list,
362372
oid_counter: Arc::new(AtomicU32::new(16384)),
363373
oid_cache: Arc::new(RwLock::new(HashMap::new())),
364-
static_tables: PgCatalogStaticTables::try_new()?,
374+
static_tables,
365375
})
366376
}
367377
}
@@ -399,10 +409,17 @@ impl ArrowTable {
399409
}
400410
}
401411

412+
impl TableFunctionImpl for ArrowTable {
413+
fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
414+
let table = self.clone().try_into_memtable()?;
415+
Ok(Arc::new(table))
416+
}
417+
}
418+
402419
/// pg_catalog table as datafusion table provider
403420
///
404421
/// This implementation only contains static tables
405-
#[derive(Debug)]
422+
#[derive(Debug, Clone)]
406423
pub struct PgCatalogStaticTables {
407424
pub pg_aggregate: Arc<dyn TableProvider>,
408425
pub pg_am: Arc<dyn TableProvider>,
@@ -461,6 +478,8 @@ pub struct PgCatalogStaticTables {
461478
pub pg_tablespace: Arc<dyn TableProvider>,
462479
pub pg_trigger: Arc<dyn TableProvider>,
463480
pub pg_user_mapping: Arc<dyn TableProvider>,
481+
482+
pub pg_get_keywords: Arc<dyn TableFunctionImpl>,
464483
}
465484

466485
impl PgCatalogStaticTables {
@@ -647,6 +666,10 @@ impl PgCatalogStaticTables {
647666
pg_user_mapping: Self::create_arrow_table(
648667
include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(),
649668
)?,
669+
670+
pg_get_keywords: Self::create_arrow_table_function(
671+
include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(),
672+
)?,
650673
})
651674
}
652675

@@ -656,6 +679,11 @@ impl PgCatalogStaticTables {
656679
let mem_table = table.try_into_memtable()?;
657680
Ok(Arc::new(mem_table))
658681
}
682+
683+
fn create_arrow_table_function(data_bytes: Vec<u8>) -> Result<Arc<dyn TableFunctionImpl>> {
684+
let table = ArrowTable::from_ipc_data(data_bytes)?;
685+
Ok(Arc::new(table))
686+
}
659687
}
660688

661689
pub fn create_current_schemas_udf() -> ScalarUDF {
@@ -862,7 +890,78 @@ pub fn create_format_type_udf() -> ScalarUDF {
862890

863891
create_udf(
864892
"format_type",
865-
vec![DataType::Int32, DataType::Int32],
893+
vec![DataType::Int64, DataType::Int32],
894+
DataType::Utf8,
895+
Volatility::Stable,
896+
Arc::new(func),
897+
)
898+
}
899+
900+
pub fn create_session_user_udf() -> ScalarUDF {
901+
let func = move |_args: &[ColumnarValue]| {
902+
let mut builder = StringBuilder::new();
903+
// TODO: return real user
904+
builder.append_value("postgres");
905+
906+
let array: ArrayRef = Arc::new(builder.finish());
907+
908+
Ok(ColumnarValue::Array(array))
909+
};
910+
911+
create_udf(
912+
"session_user",
913+
vec![],
914+
DataType::Utf8,
915+
Volatility::Stable,
916+
Arc::new(func),
917+
)
918+
}
919+
920+
pub fn create_pg_get_expr_udf() -> ScalarUDF {
921+
let func = move |args: &[ColumnarValue]| {
922+
let args = ColumnarValue::values_to_arrays(args)?;
923+
let expr = &args[0];
924+
let _oid = &args[1];
925+
926+
// For now, always return an empty string for each input row (stub expression decompiler)
927+
let mut builder = StringBuilder::new();
928+
for _ in 0..expr.len() {
929+
builder.append_value("");
930+
}
931+
932+
let array: ArrayRef = Arc::new(builder.finish());
933+
934+
Ok(ColumnarValue::Array(array))
935+
};
936+
937+
create_udf(
938+
"pg_catalog.pg_get_expr",
939+
vec![DataType::Utf8, DataType::Int32],
940+
DataType::Utf8,
941+
Volatility::Stable,
942+
Arc::new(func),
943+
)
944+
}
945+
946+
pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
947+
let func = move |args: &[ColumnarValue]| {
948+
let args = ColumnarValue::values_to_arrays(args)?;
949+
let oid = &args[0];
950+
951+
// For now, always return an empty string for each input row (no partition key support)
952+
let mut builder = StringBuilder::new();
953+
for _ in 0..oid.len() {
954+
builder.append_value("");
955+
}
956+
957+
let array: ArrayRef = Arc::new(builder.finish());
958+
959+
Ok(ColumnarValue::Array(array))
960+
};
961+
962+
create_udf(
963+
"pg_catalog.pg_get_partkeydef",
964+
vec![DataType::Utf8],
866965
DataType::Utf8,
867966
Volatility::Stable,
868967
Arc::new(func),
@@ -874,8 +973,11 @@ pub fn setup_pg_catalog(
874973
session_context: &SessionContext,
875974
catalog_name: &str,
876975
) -> Result<(), Box<DataFusionError>> {
877-
let pg_catalog =
878-
PgCatalogSchemaProvider::try_new(session_context.state().catalog_list().clone())?;
976+
let static_tables = Arc::new(PgCatalogStaticTables::try_new()?);
977+
let pg_catalog = PgCatalogSchemaProvider::try_new(
978+
session_context.state().catalog_list().clone(),
979+
static_tables.clone(),
980+
)?;
879981
session_context
880982
.catalog(catalog_name)
881983
.ok_or_else(|| {
@@ -892,6 +994,10 @@ pub fn setup_pg_catalog(
892994
session_context.register_udf(create_has_table_privilege_2param_udf());
893995
session_context.register_udf(create_pg_table_is_visible());
894996
session_context.register_udf(create_format_type_udf());
997+
session_context.register_udf(create_session_user_udf());
998+
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
999+
session_context.register_udf(create_pg_get_expr_udf());
1000+
session_context.register_udf(create_pg_get_partkeydef_udf());
8951001

8961002
Ok(())
8971003
}
@@ -1145,5 +1251,9 @@ mod test {
11451251
include_bytes!("../../pg_catalog_arrow_exports/pg_user_mapping.feather").to_vec(),
11461252
)
11471253
.expect("Failed to load ipc data");
1254+
let _ = ArrowTable::from_ipc_data(
1255+
include_bytes!("../../pg_catalog_arrow_exports/pg_get_keywords.feather").to_vec(),
1256+
)
1257+
.expect("Failed to load ipc data");
11481258
}
11491259
}

0 commit comments

Comments
 (0)