diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index ac62f13..ab03bcf 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -63,6 +63,7 @@ impl PgWireServerHandlers for HandlerFactory { } } +/// The pgwire handler backed by a datafusion `SessionContext` pub struct DfSessionService { session_context: Arc, parser: Arc, diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index c47253b..03e6591 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -10,6 +10,7 @@ use datafusion::prelude::SessionContext; pub mod auth; use getset::{Getters, Setters, WithSetters}; use log::{info, warn}; +use pgwire::api::PgWireServerHandlers; use pgwire::tokio::process_socket; use rustls_pemfile::{certs, pkcs8_private_keys}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; @@ -81,6 +82,18 @@ pub async fn serve( // Create the handler factory with authentication let factory = Arc::new(HandlerFactory::new(session_context, auth_manager)); + serve_with_handlers(factory, opts).await +} + +/// Serve with custom pgwire handlers +/// +/// This function allows you to rewrite some of the built-in logic including +/// authentication and query processing. You can Implement your own +/// `PgWireServerHandlers` by reusing `DfSessionService`. +pub async fn serve_with_handlers( + handlers: Arc, + opts: &ServerOptions, +) -> Result<(), std::io::Error> { // Set up TLS if configured let tls_acceptor = if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) { @@ -112,9 +125,8 @@ pub async fn serve( loop { match listener.accept().await { Ok((socket, _addr)) => { - let factory_ref = factory.clone(); + let factory_ref = handlers.clone(); let tls_acceptor_ref = tls_acceptor.clone(); - // Connection accepted from {addr} - log appropriately based on your logging strategy tokio::spawn(async move { if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {