Skip to content

Commit a43f367

Browse files
committed
Merge branch 'master' into refactor/pg-catalog-record-batch-access
2 parents df3c0c3 + 9c9e167 commit a43f367

File tree

6 files changed

+266
-8
lines changed

6 files changed

+266
-8
lines changed

datafusion-postgres-cli/src/main.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use datafusion::execution::options::{
66
ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
77
};
88
use datafusion::prelude::{SessionConfig, SessionContext};
9+
use datafusion_postgres::auth::AuthManager;
910
use datafusion_postgres::pg_catalog::setup_pg_catalog;
1011
use datafusion_postgres::{serve, ServerOptions};
1112
use env_logger::Env;
@@ -124,6 +125,7 @@ impl Opt {
124125
async fn setup_session_context(
125126
session_context: &SessionContext,
126127
opts: &Opt,
128+
auth_manager: Arc<AuthManager>,
127129
) -> Result<(), Box<dyn std::error::Error>> {
128130
// Register CSV tables
129131
for (table_name, table_path) in opts.csv_tables.iter().map(|s| parse_table_def(s.as_ref())) {
@@ -179,7 +181,7 @@ async fn setup_session_context(
179181
}
180182

181183
// Register pg_catalog
182-
setup_pg_catalog(session_context, "datafusion")?;
184+
setup_pg_catalog(session_context, "datafusion", auth_manager)?;
183185

184186
Ok(())
185187
}
@@ -196,16 +198,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
196198

197199
let session_config = SessionConfig::new().with_information_schema(true);
198200
let session_context = SessionContext::new_with_config(session_config);
201+
let auth_manager = Arc::new(AuthManager::new());
199202

200-
setup_session_context(&session_context, &opts).await?;
203+
setup_session_context(&session_context, &opts, Arc::clone(&auth_manager)).await?;
201204

202205
let server_options = ServerOptions::new()
203206
.with_host(opts.host)
204207
.with_port(opts.port)
205208
.with_tls_cert_path(opts.tls_cert)
206209
.with_tls_key_path(opts.tls_key);
207210

208-
serve(Arc::new(session_context), &server_options)
211+
serve(Arc::new(session_context), &server_options, auth_manager)
209212
.await
210213
.map_err(|e| format!("Failed to run server: {e}"))?;
211214

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,8 @@ fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
8383
pub async fn serve(
8484
session_context: Arc<SessionContext>,
8585
opts: &ServerOptions,
86+
auth_manager: Arc<AuthManager>,
8687
) -> Result<(), std::io::Error> {
87-
// Create authentication manager
88-
let auth_manager = Arc::new(AuthManager::new());
89-
9088
// Create the handler factory with authentication
9189
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
9290

datafusion-postgres/src/pg_catalog.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::sync::Arc;
44

55
use async_trait::async_trait;
66
use datafusion::arrow::array::{
7-
as_boolean_array, ArrayRef, AsArray, BooleanBuilder, RecordBatch, StringArray, StringBuilder,
7+
as_boolean_array, ArrayRef, AsArray, BooleanBuilder, Int32Builder, RecordBatch, StringArray,
8+
StringBuilder,
89
};
910
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, SchemaRef};
1011
use datafusion::arrow::ipc::reader::FileReader;
@@ -19,6 +20,7 @@ use datafusion::prelude::{create_udf, Expr, SessionContext};
1920
use postgres_types::Oid;
2021
use tokio::sync::RwLock;
2122

23+
use crate::auth::AuthManager;
2224
use crate::pg_catalog::catalog_info::CatalogInfo;
2325
use crate::pg_catalog::empty_table::EmptyTable;
2426

@@ -32,7 +34,9 @@ pub mod pg_database;
3234
pub mod pg_get_expr_udf;
3335
pub mod pg_namespace;
3436
pub mod pg_replication_slot;
37+
pub mod pg_roles;
3538
pub mod pg_settings;
39+
pub mod pg_stat_gssapi;
3640
pub mod pg_tables;
3741
pub mod pg_views;
3842

@@ -100,7 +104,9 @@ const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping";
100104
const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings";
101105
const PG_CATALOG_VIEW_PG_VIEWS: &str = "pg_views";
102106
const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
107+
const PG_CATALOG_VIEW_PG_ROLES: &str = "pg_roles";
103108
const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
109+
const PG_CATALOG_VIEW_PG_STAT_GSSAPI: &str = "pg_stat_gssapi";
104110
const PG_CATALOG_VIEW_PG_STAT_USER_TABLES: &str = "pg_stat_user_tables";
105111
const PG_CATALOG_VIEW_PG_REPLICATION_SLOTS: &str = "pg_replication_slots";
106112

@@ -166,7 +172,10 @@ pub const PG_CATALOG_TABLES: &[&str] = &[
166172
PG_CATALOG_TABLE_PG_TABLESPACE,
167173
PG_CATALOG_TABLE_PG_TRIGGER,
168174
PG_CATALOG_TABLE_PG_USER_MAPPING,
175+
PG_CATALOG_VIEW_PG_ROLES,
169176
PG_CATALOG_VIEW_PG_SETTINGS,
177+
PG_CATALOG_VIEW_PG_STAT_GSSAPI,
178+
PG_CATALOG_VIEW_PG_TABLES,
170179
PG_CATALOG_VIEW_PG_VIEWS,
171180
PG_CATALOG_VIEW_PG_MATVIEWS,
172181
PG_CATALOG_VIEW_PG_STAT_USER_TABLES,
@@ -188,6 +197,7 @@ pub struct PgCatalogSchemaProvider<C> {
188197
oid_counter: Arc<AtomicU32>,
189198
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
190199
static_tables: Arc<PgCatalogStaticTables>,
200+
auth_manager: Arc<AuthManager>,
191201
}
192202

193203
#[async_trait]
@@ -218,12 +228,14 @@ impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
218228
pub fn try_new(
219229
catalog_list: C,
220230
static_tables: Arc<PgCatalogStaticTables>,
231+
auth_manager: Arc<AuthManager>,
221232
) -> Result<PgCatalogSchemaProvider<C>> {
222233
Ok(Self {
223234
catalog_list,
224235
oid_counter: Arc::new(AtomicU32::new(16384)),
225236
oid_cache: Arc::new(RwLock::new(HashMap::new())),
226237
static_tables,
238+
auth_manager,
227239
})
228240
}
229241

@@ -395,6 +407,16 @@ impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
395407
let table = Arc::new(pg_settings::PgSettingsView::new());
396408
Ok(Some(PgCatalogTable::Dynamic(table)))
397409
}
410+
411+
PG_CATALOG_VIEW_PG_STAT_GSSAPI => {
412+
let table = Arc::new(pg_stat_gssapi::PgStatGssApiTable::new());
413+
Ok(Some(PgCatalogTable::Dynamic(table)))
414+
}
415+
PG_CATALOG_VIEW_PG_ROLES => {
416+
let table = Arc::new(pg_roles::PgRolesTable::new(Arc::clone(&self.auth_manager)));
417+
Ok(Some(PgCatalogTable::Dynamic(table)))
418+
}
419+
398420
PG_CATALOG_VIEW_PG_VIEWS => Ok(Some(pg_views::pg_views().into())),
399421
PG_CATALOG_VIEW_PG_MATVIEWS => Ok(Some(pg_views::pg_matviews().into())),
400422
PG_CATALOG_VIEW_PG_STAT_USER_TABLES => Ok(Some(pg_views::pg_stat_user_tables().into())),
@@ -1236,15 +1258,36 @@ pub fn create_pg_encoding_to_char_udf() -> ScalarUDF {
12361258
)
12371259
}
12381260

1261+
pub fn create_pg_backend_pid_udf() -> ScalarUDF {
1262+
let func = move |_args: &[ColumnarValue]| {
1263+
let mut builder = Int32Builder::new();
1264+
builder.append_value(BACKEND_PID);
1265+
let array: ArrayRef = Arc::new(builder.finish());
1266+
Ok(ColumnarValue::Array(array))
1267+
};
1268+
1269+
create_udf(
1270+
"pg_backend_pid",
1271+
vec![],
1272+
DataType::Int32,
1273+
Volatility::Stable,
1274+
Arc::new(func),
1275+
)
1276+
}
1277+
1278+
const BACKEND_PID: i32 = 1;
1279+
12391280
/// Install pg_catalog and postgres UDFs to current `SessionContext`
12401281
pub fn setup_pg_catalog(
12411282
session_context: &SessionContext,
12421283
catalog_name: &str,
1284+
auth_manager: Arc<AuthManager>,
12431285
) -> Result<(), Box<DataFusionError>> {
12441286
let static_tables = Arc::new(PgCatalogStaticTables::try_new()?);
12451287
let pg_catalog = PgCatalogSchemaProvider::try_new(
12461288
session_context.state().catalog_list().clone(),
12471289
static_tables.clone(),
1290+
auth_manager,
12481291
)?;
12491292
session_context
12501293
.catalog(catalog_name)
@@ -1281,6 +1324,7 @@ pub fn setup_pg_catalog(
12811324
session_context.register_udf(create_pg_relation_is_publishable_udf());
12821325
session_context.register_udf(create_pg_get_statisticsobjdef_columns_udf());
12831326
session_context.register_udf(create_pg_encoding_to_char_udf());
1327+
session_context.register_udf(create_pg_backend_pid_udf());
12841328

12851329
Ok(())
12861330
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::arrow::array::{
4+
ArrayRef, BooleanArray, Int32Array, ListBuilder, RecordBatch, StringArray, StringBuilder,
5+
TimestampMicrosecondBuilder,
6+
};
7+
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8+
use datafusion::error::Result;
9+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
10+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
11+
use datafusion::physical_plan::streaming::PartitionStream;
12+
13+
use crate::auth::AuthManager;
14+
15+
#[derive(Debug, Clone)]
16+
pub(crate) struct PgRolesTable {
17+
schema: SchemaRef,
18+
auth_manager: Arc<AuthManager>,
19+
}
20+
21+
impl PgRolesTable {
22+
pub(crate) fn new(auth_manager: Arc<AuthManager>) -> Self {
23+
let schema = Arc::new(Schema::new(vec![
24+
Field::new("rolname", DataType::Utf8, true),
25+
Field::new("rolsuper", DataType::Boolean, true),
26+
Field::new("rolinherit", DataType::Boolean, true),
27+
Field::new("rolcreaterole", DataType::Boolean, true),
28+
Field::new("rolcreatedb", DataType::Boolean, true),
29+
Field::new("rolcanlogin", DataType::Boolean, true),
30+
Field::new("rolreplication", DataType::Boolean, true),
31+
Field::new("rolconnlimit", DataType::Int32, true),
32+
Field::new("rolpassword", DataType::Utf8, true),
33+
Field::new(
34+
"rolvaliduntil",
35+
DataType::Timestamp(
36+
datafusion::arrow::datatypes::TimeUnit::Microsecond,
37+
Some(Arc::from("UTC")),
38+
),
39+
true,
40+
),
41+
Field::new("rolbypassrls", DataType::Boolean, true),
42+
Field::new(
43+
"rolconfig",
44+
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
45+
true,
46+
),
47+
Field::new("oid", DataType::Int32, true),
48+
]));
49+
50+
Self {
51+
schema,
52+
auth_manager,
53+
}
54+
}
55+
56+
async fn get_data(this: Self) -> Result<RecordBatch> {
57+
let mut rolname = Vec::new();
58+
let mut rolsuper = Vec::new();
59+
let mut rolinherit = Vec::new();
60+
let mut rolcreaterole = Vec::new();
61+
let mut rolcreatedb = Vec::new();
62+
let mut rolcanlogin = Vec::new();
63+
let mut rolreplication = Vec::new();
64+
let mut rolconnlimit = Vec::new();
65+
let mut rolpassword: Vec<Option<String>> = Vec::new();
66+
let mut rolvaliduntil: Vec<Option<i64>> = Vec::new();
67+
let mut rolbypassrls = Vec::new();
68+
let mut rolconfig: Vec<Option<Vec<String>>> = Vec::new();
69+
let mut oid: Vec<i32> = Vec::new();
70+
71+
for role_name in &this.auth_manager.list_roles().await {
72+
let role = &this.auth_manager.get_role(role_name).await.unwrap();
73+
rolname.push(role.name.clone());
74+
rolsuper.push(role.is_superuser);
75+
rolinherit.push(true);
76+
rolcreaterole.push(role.can_create_role);
77+
rolcreatedb.push(role.can_create_db);
78+
rolcanlogin.push(role.can_login);
79+
rolreplication.push(role.can_replication);
80+
rolconnlimit.push(-1);
81+
rolpassword.push(None);
82+
rolvaliduntil.push(None);
83+
rolbypassrls.push(None);
84+
rolconfig.push(None);
85+
oid.push(0); // TODO: handle oid properly somehow
86+
}
87+
88+
let arrays: Vec<ArrayRef> = vec![
89+
Arc::new(StringArray::from(rolname)),
90+
Arc::new(BooleanArray::from(rolsuper)),
91+
Arc::new(BooleanArray::from(rolinherit)),
92+
Arc::new(BooleanArray::from(rolcreaterole)),
93+
Arc::new(BooleanArray::from(rolcreatedb)),
94+
Arc::new(BooleanArray::from(rolcanlogin)),
95+
Arc::new(BooleanArray::from(rolreplication)),
96+
Arc::new(Int32Array::from(rolconnlimit)),
97+
Arc::new(StringArray::from(rolpassword)),
98+
Arc::new({
99+
let mut builder =
100+
TimestampMicrosecondBuilder::with_capacity(rolconfig.len()).with_data_type(
101+
DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))),
102+
);
103+
for field in &rolvaliduntil {
104+
builder.append_option(field.as_ref().copied());
105+
}
106+
builder.finish()
107+
}),
108+
Arc::new(BooleanArray::from(rolbypassrls)),
109+
Arc::new({
110+
let mut builder = ListBuilder::new(StringBuilder::new());
111+
for field in &rolconfig {
112+
match field {
113+
Some(values) => {
114+
for s in values {
115+
builder.values().append_value(s);
116+
}
117+
builder.append(true);
118+
}
119+
None => builder.append(false),
120+
}
121+
}
122+
builder.finish()
123+
}),
124+
Arc::new(Int32Array::from(oid)),
125+
];
126+
127+
Ok(RecordBatch::try_new(this.schema.clone(), arrays)?)
128+
}
129+
}
130+
131+
impl PartitionStream for PgRolesTable {
132+
fn schema(&self) -> &SchemaRef {
133+
&self.schema
134+
}
135+
136+
fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
137+
let this = self.clone();
138+
Box::pin(RecordBatchStreamAdapter::new(
139+
this.schema.clone(),
140+
futures::stream::once(async move { PgRolesTable::get_data(this).await }),
141+
))
142+
}
143+
}

0 commit comments

Comments
 (0)