Skip to content

Commit af30b53

Browse files
committed
feat: add a serve function
1 parent b0c1a1d commit af30b53

File tree

3 files changed

+67
-33
lines changed

3 files changed

+67
-33
lines changed

datafusion-postgres-cli/src/main.rs

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
use std::sync::Arc;
2-
31
use datafusion::execution::options::{
42
ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
53
};
64
use datafusion::prelude::SessionContext;
7-
use datafusion_postgres::{DfSessionService, HandlerFactory}; // Assuming the crate name is `datafusion_postgres`
8-
use pgwire::tokio::process_socket;
5+
use datafusion_postgres::{serve, ServerOptions}; // Assuming the crate name is `datafusion_postgres`
96
use structopt::StructOpt;
10-
use tokio::net::TcpListener;
117

128
#[derive(Debug, StructOpt)]
139
#[structopt(
@@ -103,33 +99,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
10399
println!("Loaded {} as table {}", table_path, table_name);
104100
}
105101

106-
// Get the first catalog name from the session context
107-
let catalog_name = session_context
108-
.catalog_names() // Fixed: Removed .catalog_list()
109-
.first()
110-
.cloned();
111-
112-
// Create the handler factory with the session context and catalog name
113-
let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new(
114-
session_context,
115-
catalog_name,
116-
))));
102+
let server_options = ServerOptions {
103+
host: opts.host.clone(),
104+
port: opts.port,
105+
};
117106

118-
// Bind to the specified host and port
119-
let server_addr = format!("{}:{}", opts.host, opts.port);
120-
let listener = TcpListener::bind(&server_addr).await?;
121-
println!("Listening on {}", server_addr);
107+
serve(session_context, &server_options)
108+
.await
109+
.map_err(|e| format!("Failed to run server: {}", e))?;
122110

123-
// Accept incoming connections
124-
loop {
125-
let (socket, addr) = listener.accept().await?;
126-
let factory_ref = factory.clone();
127-
println!("Accepted connection from {}", addr);
128-
129-
tokio::spawn(async move {
130-
if let Err(e) = process_socket(socket, None, factory_ref).await {
131-
eprintln!("Error processing socket: {}", e);
132-
}
133-
});
134-
}
111+
Ok(())
135112
}

datafusion-postgres/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ log = "0.4"
2525
pgwire = { workspace = true }
2626
postgres-types = "0.2"
2727
rust_decimal = { version = "1.37", features = ["db-postgres"] }
28-
tokio = { version = "1.45", features = ["sync"] }
28+
tokio = { version = "1.45", features = ["sync", "net"] }

datafusion-postgres/src/lib.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,60 @@ mod handlers;
44
mod information_schema;
55

66
pub use handlers::{DfSessionService, HandlerFactory, Parser};
7+
8+
use std::sync::Arc;
9+
10+
use datafusion::prelude::SessionContext;
11+
use pgwire::tokio::process_socket;
12+
use tokio::net::TcpListener;
13+
14+
pub struct ServerOptions {
15+
pub host: String,
16+
pub port: u16,
17+
}
18+
19+
impl Default for ServerOptions {
20+
fn default() -> Self {
21+
ServerOptions {
22+
host: "127.0.0.1".to_string(),
23+
port: 5432,
24+
}
25+
}
26+
}
27+
28+
/// Serve the Datafusion `SessionContext` with Postgres protocol.
29+
pub async fn serve(
30+
session_context: SessionContext,
31+
opts: &ServerOptions,
32+
) -> Result<(), std::io::Error> {
33+
// Get the first catalog name from the session context
34+
let catalog_name = session_context
35+
.catalog_names() // Fixed: Removed .catalog_list()
36+
.first()
37+
.cloned();
38+
39+
// Create the handler factory with the session context and catalog name
40+
let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new(
41+
session_context,
42+
catalog_name,
43+
))));
44+
45+
// Bind to the specified host and port
46+
let server_addr = format!("{}:{}", opts.host, opts.port);
47+
let listener = TcpListener::bind(&server_addr).await?;
48+
println!("Listening on {}", server_addr);
49+
50+
// Accept incoming connections
51+
loop {
52+
if let Ok((socket, addr)) = listener.accept().await {
53+
let factory_ref = factory.clone();
54+
println!("Accepted connection from {}", addr);
55+
56+
tokio::spawn(async move {
57+
if let Err(e) = process_socket(socket, None, factory_ref).await {
58+
eprintln!("Error processing socket: {}", e);
59+
}
60+
});
61+
};
62+
}
63+
}

0 commit comments

Comments
 (0)