Skip to content

Commit da23f3e

Browse files
committed
Add support for pg_roles
Implement support for the `pg_roles` table in the catalog, and use the existing information from the Roles setup to provide the data where possible.
1 parent 26fb0db commit da23f3e

File tree

5 files changed

+171
-7
lines changed

5 files changed

+171
-7
lines changed

datafusion-postgres-cli/src/main.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use datafusion::execution::options::{
66
ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
77
};
88
use datafusion::prelude::{SessionConfig, SessionContext};
9+
use datafusion_postgres::auth::AuthManager;
910
use datafusion_postgres::pg_catalog::setup_pg_catalog;
1011
use datafusion_postgres::{serve, ServerOptions};
1112
use env_logger::Env;
@@ -124,6 +125,7 @@ impl Opt {
124125
async fn setup_session_context(
125126
session_context: &SessionContext,
126127
opts: &Opt,
128+
auth_manager: Arc<AuthManager>,
127129
) -> Result<(), Box<dyn std::error::Error>> {
128130
// Register CSV tables
129131
for (table_name, table_path) in opts.csv_tables.iter().map(|s| parse_table_def(s.as_ref())) {
@@ -179,7 +181,7 @@ async fn setup_session_context(
179181
}
180182

181183
// Register pg_catalog
182-
setup_pg_catalog(session_context, "datafusion")?;
184+
setup_pg_catalog(session_context, "datafusion", auth_manager)?;
183185

184186
Ok(())
185187
}
@@ -196,16 +198,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
196198

197199
let session_config = SessionConfig::new().with_information_schema(true);
198200
let session_context = SessionContext::new_with_config(session_config);
201+
let auth_manager = Arc::new(AuthManager::new());
199202

200-
setup_session_context(&session_context, &opts).await?;
203+
setup_session_context(&session_context, &opts, Arc::clone(&auth_manager)).await?;
201204

202205
let server_options = ServerOptions::new()
203206
.with_host(opts.host)
204207
.with_port(opts.port)
205208
.with_tls_cert_path(opts.tls_cert)
206209
.with_tls_key_path(opts.tls_key);
207210

208-
serve(Arc::new(session_context), &server_options)
211+
serve(Arc::new(session_context), &server_options, auth_manager)
209212
.await
210213
.map_err(|e| format!("Failed to run server: {e}"))?;
211214

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,8 @@ fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
8383
pub async fn serve(
8484
session_context: Arc<SessionContext>,
8585
opts: &ServerOptions,
86+
auth_manager: Arc<AuthManager>,
8687
) -> Result<(), std::io::Error> {
87-
// Create authentication manager
88-
let auth_manager = Arc::new(AuthManager::new());
89-
9088
// Create the handler factory with authentication
9189
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
9290

datafusion-postgres/src/pg_catalog.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use datafusion::prelude::{create_udf, Expr, SessionContext};
2020
use postgres_types::Oid;
2121
use tokio::sync::RwLock;
2222

23+
use crate::auth::AuthManager;
2324
use crate::pg_catalog::catalog_info::CatalogInfo;
2425

2526
pub mod catalog_info;
@@ -32,6 +33,7 @@ pub mod pg_database;
3233
pub mod pg_get_expr_udf;
3334
pub mod pg_namespace;
3435
pub mod pg_replication_slot;
36+
pub mod pg_roles;
3537
pub mod pg_settings;
3638
pub mod pg_stat_gssapi;
3739
pub mod pg_tables;
@@ -101,6 +103,7 @@ const PG_CATALOG_TABLE_PG_USER_MAPPING: &str = "pg_user_mapping";
101103
const PG_CATALOG_VIEW_PG_SETTINGS: &str = "pg_settings";
102104
const PG_CATALOG_VIEW_PG_VIEWS: &str = "pg_views";
103105
const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
106+
const PG_CATALOG_VIEW_PG_ROLES: &str = "pg_roles";
104107
const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
105108
const PG_CATALOG_VIEW_PG_STAT_GSSAPI: &str = "pg_stat_gssapi";
106109
const PG_CATALOG_VIEW_PG_STAT_USER_TABLES: &str = "pg_stat_user_tables";
@@ -190,6 +193,7 @@ pub struct PgCatalogSchemaProvider<C> {
190193
oid_counter: Arc<AtomicU32>,
191194
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
192195
static_tables: Arc<PgCatalogStaticTables>,
196+
auth_manager: Arc<AuthManager>,
193197
}
194198

195199
#[async_trait]
@@ -349,6 +353,13 @@ impl<C: CatalogInfo> SchemaProvider for PgCatalogSchemaProvider<C> {
349353
vec![table],
350354
)?)))
351355
}
356+
PG_CATALOG_VIEW_PG_ROLES => {
357+
let table = Arc::new(pg_roles::PgRolesTable::new(Arc::clone(&self.auth_manager)));
358+
Ok(Some(Arc::new(StreamingTable::try_new(
359+
Arc::clone(table.schema()),
360+
vec![table],
361+
)?)))
362+
}
352363
PG_CATALOG_VIEW_PG_TABLES => {
353364
let table = Arc::new(pg_tables::PgTablesTable::new(self.catalog_list.clone()));
354365
Ok(Some(Arc::new(StreamingTable::try_new(
@@ -382,12 +393,14 @@ impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
382393
pub fn try_new(
383394
catalog_list: C,
384395
static_tables: Arc<PgCatalogStaticTables>,
396+
auth_manager: Arc<AuthManager>,
385397
) -> Result<PgCatalogSchemaProvider<C>> {
386398
Ok(Self {
387399
catalog_list,
388400
oid_counter: Arc::new(AtomicU32::new(16384)),
389401
oid_cache: Arc::new(RwLock::new(HashMap::new())),
390402
static_tables,
403+
auth_manager,
391404
})
392405
}
393406
}
@@ -1195,11 +1208,13 @@ const BACKEND_PID: i32 = 1;
11951208
pub fn setup_pg_catalog(
11961209
session_context: &SessionContext,
11971210
catalog_name: &str,
1211+
auth_manager: Arc<AuthManager>,
11981212
) -> Result<(), Box<DataFusionError>> {
11991213
let static_tables = Arc::new(PgCatalogStaticTables::try_new()?);
12001214
let pg_catalog = PgCatalogSchemaProvider::try_new(
12011215
session_context.state().catalog_list().clone(),
12021216
static_tables.clone(),
1217+
auth_manager,
12031218
)?;
12041219
session_context
12051220
.catalog(catalog_name)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::arrow::array::{
4+
ArrayRef, BooleanArray, Int32Array, ListBuilder, RecordBatch, StringArray, StringBuilder,
5+
TimestampMicrosecondBuilder,
6+
};
7+
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
8+
use datafusion::error::Result;
9+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
10+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
11+
use datafusion::physical_plan::streaming::PartitionStream;
12+
13+
use crate::auth::AuthManager;
14+
15+
#[derive(Debug, Clone)]
16+
pub(crate) struct PgRolesTable {
17+
schema: SchemaRef,
18+
auth_manager: Arc<AuthManager>,
19+
}
20+
21+
impl PgRolesTable {
22+
pub(crate) fn new(auth_manager: Arc<AuthManager>) -> Self {
23+
let schema = Arc::new(Schema::new(vec![
24+
Field::new("rolname", DataType::Utf8, true),
25+
Field::new("rolsuper", DataType::Boolean, true),
26+
Field::new("rolinherit", DataType::Boolean, true),
27+
Field::new("rolcreaterole", DataType::Boolean, true),
28+
Field::new("rolcreatedb", DataType::Boolean, true),
29+
Field::new("rolcanlogin", DataType::Boolean, true),
30+
Field::new("rolreplication", DataType::Boolean, true),
31+
Field::new("rolconnlimit", DataType::Int32, true),
32+
Field::new("rolpassword", DataType::Utf8, true),
33+
Field::new(
34+
"rolvaliduntil",
35+
DataType::Timestamp(
36+
datafusion::arrow::datatypes::TimeUnit::Microsecond,
37+
Some(Arc::from("UTC")),
38+
),
39+
true,
40+
),
41+
Field::new("rolbypassrls", DataType::Boolean, true),
42+
Field::new(
43+
"rolconfig",
44+
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
45+
true,
46+
),
47+
Field::new("oid", DataType::Int32, true),
48+
]));
49+
50+
Self {
51+
schema,
52+
auth_manager,
53+
}
54+
}
55+
56+
async fn get_data(this: Self) -> Result<RecordBatch> {
57+
let mut rolname = Vec::new();
58+
let mut rolsuper = Vec::new();
59+
let mut rolinherit = Vec::new();
60+
let mut rolcreaterole = Vec::new();
61+
let mut rolcreatedb = Vec::new();
62+
let mut rolcanlogin = Vec::new();
63+
let mut rolreplication = Vec::new();
64+
let mut rolconnlimit = Vec::new();
65+
let mut rolpassword: Vec<Option<String>> = Vec::new();
66+
let mut rolvaliduntil: Vec<Option<i64>> = Vec::new();
67+
let mut rolbypassrls = Vec::new();
68+
let mut rolconfig: Vec<Option<Vec<String>>> = Vec::new();
69+
let mut oid: Vec<i32> = Vec::new();
70+
71+
for role_name in &this.auth_manager.list_roles().await {
72+
let role = &this.auth_manager.get_role(&role_name).await.unwrap();
73+
rolname.push(role.name.clone());
74+
rolsuper.push(role.is_superuser);
75+
rolinherit.push(true);
76+
rolcreaterole.push(role.can_create_role);
77+
rolcreatedb.push(role.can_create_db);
78+
rolcanlogin.push(role.can_login);
79+
rolreplication.push(role.can_replication);
80+
rolconnlimit.push(-1);
81+
rolpassword.push(None);
82+
rolvaliduntil.push(None);
83+
rolbypassrls.push(None);
84+
rolconfig.push(None);
85+
oid.push(0); // TODO: handle oid properly somehow
86+
}
87+
88+
let arrays: Vec<ArrayRef> = vec![
89+
Arc::new(StringArray::from(rolname)),
90+
Arc::new(BooleanArray::from(rolsuper)),
91+
Arc::new(BooleanArray::from(rolinherit)),
92+
Arc::new(BooleanArray::from(rolcreaterole)),
93+
Arc::new(BooleanArray::from(rolcreatedb)),
94+
Arc::new(BooleanArray::from(rolcanlogin)),
95+
Arc::new(BooleanArray::from(rolreplication)),
96+
Arc::new(Int32Array::from(rolconnlimit)),
97+
Arc::new(StringArray::from(rolpassword)),
98+
Arc::new({
99+
let mut builder =
100+
TimestampMicrosecondBuilder::with_capacity(rolconfig.len()).with_data_type(
101+
DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))),
102+
);
103+
for field in &rolvaliduntil {
104+
builder.append_option(field.as_ref().copied());
105+
}
106+
builder.finish()
107+
}),
108+
Arc::new(BooleanArray::from(rolbypassrls)),
109+
Arc::new({
110+
let mut builder = ListBuilder::new(StringBuilder::new());
111+
for field in &rolconfig {
112+
match field {
113+
Some(values) => {
114+
for s in values {
115+
builder.values().append_value(s);
116+
}
117+
builder.append(true);
118+
}
119+
None => builder.append(false),
120+
}
121+
}
122+
builder.finish()
123+
}),
124+
Arc::new(Int32Array::from(oid)),
125+
];
126+
127+
Ok(RecordBatch::try_new(this.schema.clone(), arrays)?)
128+
}
129+
}
130+
131+
impl PartitionStream for PgRolesTable {
132+
fn schema(&self) -> &SchemaRef {
133+
&self.schema
134+
}
135+
136+
fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
137+
let this = self.clone();
138+
Box::pin(RecordBatchStreamAdapter::new(
139+
this.schema.clone(),
140+
futures::stream::once(async move { PgRolesTable::get_data(this).await }),
141+
))
142+
}
143+
}

datafusion-postgres/tests/common/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ use pgwire::{
1212

1313
pub fn setup_handlers() -> DfSessionService {
1414
let session_context = SessionContext::new();
15-
setup_pg_catalog(&session_context, "datafusion").expect("Failed to setup sesession context");
15+
setup_pg_catalog(
16+
&session_context,
17+
"datafusion",
18+
Arc::new(AuthManager::default()),
19+
)
20+
.expect("Failed to setup sesession context");
1621

1722
DfSessionService::new(Arc::new(session_context), Arc::new(AuthManager::new()))
1823
}

0 commit comments

Comments
 (0)