Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,50 @@ impl RivetEngine {
Ok(())
}

/// Connect to a new external data source and register it as a catalog.
pub async fn connect(&self, name: &str, source: Source) -> Result<()> {
/// Persist a connection's configuration and expose it as an (initially empty) catalog.
///
/// The source is serialized to JSON and stored through the catalog under `name`,
/// then a `RivetCatalogProvider` for it is registered with the DataFusion context.
/// No table discovery is performed here; call `discover_connection()` afterwards
/// to populate table metadata.
///
/// Returns the connection id produced by the catalog.
///
/// # Errors
///
/// Fails if the source cannot be serialized to JSON or if the catalog rejects
/// the new connection.
pub async fn register_connection(&self, name: &str, source: Source) -> Result<i32> {
    let kind = source.source_type();

    // The serialized form carries the "type" discriminator from the serde tag.
    let serialized = serde_json::to_string(&source)?;
    let id = self.catalog.add_connection(name, kind, &serialized).await?;

    // The provider is registered before any tables are known; discovery fills them in later.
    let provider: Arc<dyn CatalogProvider> = Arc::new(RivetCatalogProvider::new(
        id,
        name.to_owned(),
        Arc::new(source),
        self.catalog.clone(),
        self.orchestrator.clone(),
    ));
    self.df_ctx.register_catalog(name, provider);

    info!("Connection '{}' registered (discovery pending)", name);

    Ok(id)
}

/// Discover tables for an existing connection.
///
/// Connects to the remote database, discovers available tables, and stores
/// their metadata in the catalog. Returns the number of tables discovered.
pub async fn discover_connection(&self, name: &str) -> Result<usize> {
// Get connection info
let conn = self
.catalog
.get_connection(name)
.await?
.ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?;

let source: Source = serde_json::from_str(&conn.config_json)?;
let source_type = source.source_type();

// Discover tables
Expand All @@ -224,40 +266,32 @@ impl RivetEngine {

info!("Discovered {} tables", tables.len());

// Store config as JSON (includes "type" from serde tag)
let config_json = serde_json::to_string(&source)?;
let conn_id = self
.catalog
.add_connection(name, source_type, &config_json)
.await?;

// Add discovered tables to catalog with schema in one call
// Add discovered tables to catalog with schema
for table in &tables {
let schema = table.to_arrow_schema();
let schema_json = serde_json::to_string(schema.as_ref())
.map_err(|e| anyhow::anyhow!("Failed to serialize schema: {}", e))?;
self.catalog
.add_table(conn_id, &table.schema_name, &table.table_name, &schema_json)
.add_table(conn.id, &table.schema_name, &table.table_name, &schema_json)
.await?;
}

// Register with DataFusion
let catalog_provider = Arc::new(RivetCatalogProvider::new(
conn_id,
name.to_string(),
Arc::new(source),
self.catalog.clone(),
self.orchestrator.clone(),
)) as Arc<dyn CatalogProvider>;

self.df_ctx.register_catalog(name, catalog_provider);

info!(
"Connection '{}' registered with {} tables",
"Connection '{}' discovery complete: {} tables",
name,
tables.len()
);

Ok(tables.len())
}

/// Register an external data source as a catalog and discover its tables in one call.
///
/// Convenience wrapper that runs `register_connection()` followed by
/// `discover_connection()`. Call those two methods directly when registration
/// and discovery need to be handled separately.
///
/// # Errors
///
/// Propagates the first error from either step. Registration is not rolled back
/// here, so a discovery failure leaves the connection registered.
pub async fn connect(&self, name: &str, source: Source) -> Result<()> {
    let _conn_id = self.register_connection(name, source).await?;
    let _table_count = self.discover_connection(name).await?;
    Ok(())
}

Expand Down
8 changes: 5 additions & 3 deletions src/http/app_server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::http::handlers::{
create_connection_handler, delete_connection_handler, get_connection_handler, health_handler,
list_connections_handler, purge_connection_cache_handler, purge_table_cache_handler,
query_handler, tables_handler,
create_connection_handler, delete_connection_handler, discover_connection_handler,
get_connection_handler, health_handler, list_connections_handler,
purge_connection_cache_handler, purge_table_cache_handler, query_handler, tables_handler,
};
use crate::RivetEngine;
use axum::routing::{delete, get, post};
Expand All @@ -18,6 +18,7 @@ pub const PATH_TABLES: &str = "/tables";
/// Route path for the health endpoint.
pub const PATH_HEALTH: &str = "/health";
/// Route path for the connection collection.
pub const PATH_CONNECTIONS: &str = "/connections";
/// Route path for a single connection; `{name}` is the connection name path parameter.
pub const PATH_CONNECTION: &str = "/connections/{name}";
/// Route path for triggering table discovery on an existing connection.
pub const PATH_CONNECTION_DISCOVER: &str = "/connections/{name}/discover";
/// Route path for the cache of a single connection.
pub const PATH_CONNECTION_CACHE: &str = "/connections/{name}/cache";
/// Route path for the cache of a single table, addressed by connection, schema and table.
pub const PATH_TABLE_CACHE: &str = "/connections/{name}/tables/{schema}/{table}/cache";

Expand All @@ -37,6 +38,7 @@ impl AppServer {
PATH_CONNECTION,
get(get_connection_handler).delete(delete_connection_handler),
)
.route(PATH_CONNECTION_DISCOVER, post(discover_connection_handler))
.route(
PATH_CONNECTION_CACHE,
delete(purge_connection_cache_handler),
Expand Down
91 changes: 68 additions & 23 deletions src/http/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::http::error::ApiError;
use crate::http::models::{
ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, GetConnectionResponse,
ListConnectionsResponse, QueryRequest, QueryResponse, TableInfo, TablesResponse,
ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, DiscoverConnectionResponse,
DiscoveryStatus, GetConnectionResponse, ListConnectionsResponse, QueryRequest, QueryResponse,
TableInfo, TablesResponse,
};
use crate::http::serialization::{encode_value_at, make_array_encoder};
use crate::source::Source;
Expand Down Expand Up @@ -186,39 +187,83 @@ pub async fn create_connection_handler(

let source_type = source.source_type().to_string();

// Attempt to connect (discovers tables and registers catalog)
engine.connect(&request.name, source).await.map_err(|e| {
error!("Failed to connect to database: {}", e);
// Extract root cause message only - don't expose full stack trace to clients
let root_cause = e.root_cause().to_string();
let msg = root_cause.lines().next().unwrap_or("Unknown error");

if msg.contains("Failed to connect") || msg.contains("connection refused") {
ApiError::bad_gateway(format!("Failed to connect to database: {}", msg))
} else if msg.contains("Unsupported source type") || msg.contains("Invalid configuration") {
ApiError::bad_request(msg.to_string())
} else {
ApiError::bad_gateway(format!("Failed to connect to database: {}", msg))
}
})?;

// Count discovered tables
let tables_discovered = engine
.list_tables(Some(&request.name))
// Step 1: Register the connection
engine
.register_connection(&request.name, source)
.await
.map(|t| t.len())
.unwrap_or(0);
.map_err(|e| {
error!("Failed to register connection: {}", e);
ApiError::internal_error(format!("Failed to register connection: {}", e))
})?;

// Step 2: Attempt discovery - catch errors and return partial success
let (tables_discovered, discovery_status, discovery_error) =
match engine.discover_connection(&request.name).await {
Ok(count) => (count, DiscoveryStatus::Success, None),
Err(e) => {
let root_cause = e.root_cause().to_string();
let msg = root_cause
.lines()
.next()
.unwrap_or("Unknown error")
.to_string();
error!(
"Discovery failed for connection '{}': {}",
request.name, msg
);
(0, DiscoveryStatus::Failed, Some(msg))
}
};

Ok((
StatusCode::CREATED,
Json(CreateConnectionResponse {
name: request.name,
source_type,
tables_discovered,
discovery_status,
discovery_error,
}),
))
}

/// Handler for POST /connections/{name}/discover
pub async fn discover_connection_handler(
State(engine): State<Arc<RivetEngine>>,
Path(name): Path<String>,
) -> Result<Json<DiscoverConnectionResponse>, ApiError> {
// Validate connection exists
if engine.catalog().get_connection(&name).await?.is_none() {
return Err(ApiError::not_found(format!(
"Connection '{}' not found",
name
)));
}

// Attempt discovery
let (tables_discovered, discovery_status, discovery_error) =
match engine.discover_connection(&name).await {
Ok(count) => (count, DiscoveryStatus::Success, None),
Err(e) => {
let root_cause = e.root_cause().to_string();
let msg = root_cause
.lines()
.next()
.unwrap_or("Unknown error")
.to_string();
error!("Discovery failed for connection '{}': {}", name, msg);
(0, DiscoveryStatus::Failed, Some(msg))
}
};

Ok(Json(DiscoverConnectionResponse {
name,
tables_discovered,
discovery_status,
discovery_error,
}))
}

/// Handler for GET /connections
pub async fn list_connections_handler(
State(engine): State<Arc<RivetEngine>>,
Expand Down
25 changes: 25 additions & 0 deletions src/http/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,37 @@ pub struct CreateConnectionRequest {
pub config: serde_json::Value,
}

/// Outcome of the table-discovery step, as reported in connection API responses.
///
/// Serialized in snake_case (`success`, `skipped`, `failed`) per the
/// `rename_all` attribute below.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryStatus {
/// Discovery ran and completed without error.
Success,
/// Discovery was skipped (e.g., skip_discovery=true)
// NOTE(review): where `skip_discovery` is accepted and where this variant is
// produced is not shown alongside this enum — confirm against the handlers.
Skipped,
/// Discovery failed (connection still registered)
Failed,
}

/// Response body for POST /connections
///
/// Describes the newly created connection together with the outcome of the
/// table-discovery step.
#[derive(Debug, Serialize)]
pub struct CreateConnectionResponse {
/// Name of the connection.
pub name: String,
/// Source type of the connection, as a string.
pub source_type: String,
/// Number of tables discovered.
pub tables_discovered: usize,
/// Outcome of the discovery step.
pub discovery_status: DiscoveryStatus,
/// Error message describing a discovery failure; left out of the serialized
/// output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub discovery_error: Option<String>,
}

/// Response body for POST /connections/{name}/discover
///
/// Reports the outcome of running table discovery for an existing connection.
#[derive(Debug, Serialize)]
pub struct DiscoverConnectionResponse {
/// Name of the connection discovery was run for.
pub name: String,
/// Number of tables discovered.
pub tables_discovered: usize,
/// Outcome of the discovery step.
pub discovery_status: DiscoveryStatus,
/// Error message describing a discovery failure; left out of the serialized
/// output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub discovery_error: Option<String>,
}

/// Single connection metadata for API responses
Expand Down
Loading
Loading