diff --git a/src/engine.rs b/src/engine.rs index 445d68b..68aafc1 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -210,8 +210,50 @@ impl RivetEngine { Ok(()) } - /// Connect to a new external data source and register it as a catalog. - pub async fn connect(&self, name: &str, source: Source) -> Result<()> { + /// Register a connection without discovering tables. + /// + /// This persists the connection config to the catalog and registers it with DataFusion, + /// but does not attempt to connect to the remote database or discover tables. + /// Use `discover_connection()` to discover tables after registration. + pub async fn register_connection(&self, name: &str, source: Source) -> Result { + let source_type = source.source_type(); + + // Store config as JSON (includes "type" from serde tag) + let config_json = serde_json::to_string(&source)?; + let conn_id = self + .catalog + .add_connection(name, source_type, &config_json) + .await?; + + // Register with DataFusion (empty catalog - no tables yet) + let catalog_provider = Arc::new(RivetCatalogProvider::new( + conn_id, + name.to_string(), + Arc::new(source), + self.catalog.clone(), + self.orchestrator.clone(), + )) as Arc; + + self.df_ctx.register_catalog(name, catalog_provider); + + info!("Connection '{}' registered (discovery pending)", name); + + Ok(conn_id) + } + + /// Discover tables for an existing connection. + /// + /// Connects to the remote database, discovers available tables, and stores + /// their metadata in the catalog. Returns the number of tables discovered. + pub async fn discover_connection(&self, name: &str) -> Result { + // Get connection info + let conn = self + .catalog + .get_connection(name) + .await? + .ok_or_else(|| anyhow::anyhow!("Connection '{}' not found", name))?; + + let source: Source = serde_json::from_str(&conn.config_json)?; let source_type = source.source_type(); // Discover tables @@ -224,40 +266,32 @@ impl RivetEngine { info!("Discovered {} tables", tables.len()); - // Store config as JSON (includes "type" from serde tag) - let config_json = serde_json::to_string(&source)?; - let conn_id = self - .catalog - .add_connection(name, source_type, &config_json) - .await?; - - // Add discovered tables to catalog with schema in one call + // Add discovered tables to catalog with schema for table in &tables { let schema = table.to_arrow_schema(); let schema_json = serde_json::to_string(schema.as_ref()) .map_err(|e| anyhow::anyhow!("Failed to serialize schema: {}", e))?; self.catalog - .add_table(conn_id, &table.schema_name, &table.table_name, &schema_json) + .add_table(conn.id, &table.schema_name, &table.table_name, &schema_json) .await?; } - // Register with DataFusion - let catalog_provider = Arc::new(RivetCatalogProvider::new( - conn_id, - name.to_string(), - Arc::new(source), - self.catalog.clone(), - self.orchestrator.clone(), - )) as Arc; - - self.df_ctx.register_catalog(name, catalog_provider); - info!( - "Connection '{}' registered with {} tables", + "Connection '{}' discovery complete: {} tables", name, tables.len() ); + Ok(tables.len()) + } + + /// Connect to a new external data source and register it as a catalog. + /// + /// This is a convenience method that combines `register_connection()` and + /// `discover_connection()`. For more control, use those methods separately. + pub async fn connect(&self, name: &str, source: Source) -> Result<()> { + self.register_connection(name, source).await?; + self.discover_connection(name).await?; Ok(()) } diff --git a/src/http/app_server.rs b/src/http/app_server.rs index 6f69561..471953b 100644 --- a/src/http/app_server.rs +++ b/src/http/app_server.rs @@ -1,7 +1,7 @@ use crate::http::handlers::{ - create_connection_handler, delete_connection_handler, get_connection_handler, health_handler, - list_connections_handler, purge_connection_cache_handler, purge_table_cache_handler, - query_handler, tables_handler, + create_connection_handler, delete_connection_handler, discover_connection_handler, + get_connection_handler, health_handler, list_connections_handler, + purge_connection_cache_handler, purge_table_cache_handler, query_handler, tables_handler, }; use crate::RivetEngine; use axum::routing::{delete, get, post}; @@ -18,6 +18,7 @@ pub const PATH_TABLES: &str = "/tables"; pub const PATH_HEALTH: &str = "/health"; pub const PATH_CONNECTIONS: &str = "/connections"; pub const PATH_CONNECTION: &str = "/connections/{name}"; +pub const PATH_CONNECTION_DISCOVER: &str = "/connections/{name}/discover"; pub const PATH_CONNECTION_CACHE: &str = "/connections/{name}/cache"; pub const PATH_TABLE_CACHE: &str = "/connections/{name}/tables/{schema}/{table}/cache"; @@ -37,6 +38,7 @@ impl AppServer { PATH_CONNECTION, get(get_connection_handler).delete(delete_connection_handler), ) + .route(PATH_CONNECTION_DISCOVER, post(discover_connection_handler)) .route( PATH_CONNECTION_CACHE, delete(purge_connection_cache_handler), diff --git a/src/http/handlers.rs b/src/http/handlers.rs index 5299b6f..103f626 100644 --- a/src/http/handlers.rs +++ b/src/http/handlers.rs @@ -1,7 +1,8 @@ use crate::http::error::ApiError; use crate::http::models::{ - ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, GetConnectionResponse, - ListConnectionsResponse, QueryRequest, QueryResponse, TableInfo, TablesResponse, + ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, DiscoverConnectionResponse, + DiscoveryStatus, GetConnectionResponse, ListConnectionsResponse, QueryRequest, QueryResponse, + TableInfo, TablesResponse, }; use crate::http::serialization::{encode_value_at, make_array_encoder}; use crate::source::Source; @@ -186,28 +187,33 @@ pub async fn create_connection_handler( let source_type = source.source_type().to_string(); - // Attempt to connect (discovers tables and registers catalog) - engine.connect(&request.name, source).await.map_err(|e| { - error!("Failed to connect to database: {}", e); - // Extract root cause message only - don't expose full stack trace to clients - let root_cause = e.root_cause().to_string(); - let msg = root_cause.lines().next().unwrap_or("Unknown error"); - - if msg.contains("Failed to connect") || msg.contains("connection refused") { - ApiError::bad_gateway(format!("Failed to connect to database: {}", msg)) - } else if msg.contains("Unsupported source type") || msg.contains("Invalid configuration") { - ApiError::bad_request(msg.to_string()) - } else { - ApiError::bad_gateway(format!("Failed to connect to database: {}", msg)) - } - })?; - - // Count discovered tables - let tables_discovered = engine - .list_tables(Some(&request.name)) + // Step 1: Register the connection + engine + .register_connection(&request.name, source) .await - .map(|t| t.len()) - .unwrap_or(0); + .map_err(|e| { + error!("Failed to register connection: {}", e); + ApiError::internal_error(format!("Failed to register connection: {}", e)) + })?; + + // Step 2: Attempt discovery - catch errors and return partial success + let (tables_discovered, discovery_status, discovery_error) = + match engine.discover_connection(&request.name).await { + Ok(count) => (count, DiscoveryStatus::Success, None), + Err(e) => { + let root_cause = e.root_cause().to_string(); + let msg = root_cause + .lines() + .next() + .unwrap_or("Unknown error") + .to_string(); + error!( + "Discovery failed for connection '{}': {}", + request.name, msg + ); + (0, DiscoveryStatus::Failed, Some(msg)) + } + }; Ok(( StatusCode::CREATED, @@ -215,10 +221,49 @@ pub async fn create_connection_handler( name: request.name, source_type, tables_discovered, + discovery_status, + discovery_error, }), )) } +/// Handler for POST /connections/{name}/discover +pub async fn discover_connection_handler( + State(engine): State>, + Path(name): Path, +) -> Result, ApiError> { + // Validate connection exists + if engine.catalog().get_connection(&name).await?.is_none() { + return Err(ApiError::not_found(format!( + "Connection '{}' not found", + name + ))); + } + + // Attempt discovery + let (tables_discovered, discovery_status, discovery_error) = + match engine.discover_connection(&name).await { + Ok(count) => (count, DiscoveryStatus::Success, None), + Err(e) => { + let root_cause = e.root_cause().to_string(); + let msg = root_cause + .lines() + .next() + .unwrap_or("Unknown error") + .to_string(); + error!("Discovery failed for connection '{}': {}", name, msg); + (0, DiscoveryStatus::Failed, Some(msg)) + } + }; + + Ok(Json(DiscoverConnectionResponse { + name, + tables_discovered, + discovery_status, + discovery_error, + })) +} + /// Handler for GET /connections pub async fn list_connections_handler( State(engine): State>, diff --git a/src/http/models.rs b/src/http/models.rs index e01debf..2d5a6f7 100644 --- a/src/http/models.rs +++ b/src/http/models.rs @@ -39,12 +39,37 @@ pub struct CreateConnectionRequest { pub config: serde_json::Value, } +/// Discovery status for connection creation +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DiscoveryStatus { + /// Discovery succeeded + Success, + /// Discovery was skipped (e.g., skip_discovery=true) + Skipped, + /// Discovery failed (connection still registered) + Failed, +} + /// Response body for POST /connections #[derive(Debug, Serialize)] pub struct CreateConnectionResponse { pub name: String, pub source_type: String, pub tables_discovered: usize, + pub discovery_status: DiscoveryStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub discovery_error: Option, +} + +/// Response body for POST /connections/{name}/discover +#[derive(Debug, Serialize)] +pub struct DiscoverConnectionResponse { + pub name: String, + pub tables_discovered: usize, + pub discovery_status: DiscoveryStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub discovery_error: Option, } /// Single connection metadata for API responses diff --git a/tests/http_server_tests.rs b/tests/http_server_tests.rs index e3bf610..7d1bde6 100644 --- a/tests/http_server_tests.rs +++ b/tests/http_server_tests.rs @@ -5,7 +5,9 @@ use axum::{ http::{Request, StatusCode}, Router, }; -use rivetdb::http::app_server::{AppServer, PATH_CONNECTIONS, PATH_QUERY, PATH_TABLES}; +use rivetdb::http::app_server::{ + AppServer, PATH_CONNECTIONS, PATH_CONNECTION_DISCOVER, PATH_QUERY, PATH_TABLES, +}; use rivetdb::RivetEngine; use serde_json::json; use tempfile::TempDir; @@ -448,3 +450,318 @@ async fn test_create_connection_missing_fields() -> Result<()> { Ok(()) } + +// ==================== Decoupled Registration/Discovery Tests ==================== + +#[tokio::test(flavor = "multi_thread")] +async fn test_create_connection_registers_even_when_discovery_fails() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + // Create a postgres connection with invalid credentials - registration should + // succeed but discovery should fail + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "my_pg", + "source_type": "postgres", + "config": { + "host": "localhost", + "port": 5432, + "user": "nonexistent_user", + "password": "bad_password", + "database": "nonexistent_db" + } + }))?))?, + ) + .await?; + + // Should return 201 CREATED (connection was registered) + assert_eq!(response.status(), StatusCode::CREATED); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + // Verify response structure + assert_eq!(json["name"], "my_pg"); + assert_eq!(json["source_type"], "postgres"); + assert_eq!(json["tables_discovered"], 0); + assert_eq!(json["discovery_status"], "failed"); + assert!(json["discovery_error"].is_string()); + + // Verify connection exists by listing connections + let list_response = app + .oneshot( + Request::builder() + .method("GET") + .uri(PATH_CONNECTIONS) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(list_response.status(), StatusCode::OK); + + let list_body = axum::body::to_bytes(list_response.into_body(), usize::MAX).await?; + let list_json: serde_json::Value = serde_json::from_slice(&list_body)?; + + assert_eq!(list_json["connections"].as_array().unwrap().len(), 1); + assert_eq!(list_json["connections"][0]["name"], "my_pg"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_discover_connection_not_found() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + let discover_path = PATH_CONNECTION_DISCOVER.replace("{name}", "nonexistent"); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri(&discover_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert!(json["error"]["message"] + .as_str() + .unwrap() + .contains("not found")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_discover_connection_retry_after_failed_discovery() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + // First create a connection with invalid credentials + let create_response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "retry_conn", + "source_type": "postgres", + "config": { + "host": "localhost", + "port": 5432, + "user": "bad_user", + "password": "bad_pass", + "database": "bad_db" + } + }))?))?, + ) + .await?; + + assert_eq!(create_response.status(), StatusCode::CREATED); + + // Now try to discover again via the discover endpoint + let discover_path = PATH_CONNECTION_DISCOVER.replace("{name}", "retry_conn"); + + let discover_response = app + .oneshot( + Request::builder() + .method("POST") + .uri(&discover_path) + .body(Body::empty())?, + ) + .await?; + + // Should return 200 OK (endpoint works, even though discovery fails) + assert_eq!(discover_response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(discover_response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + // Verify response structure for discover endpoint + assert_eq!(json["name"], "retry_conn"); + assert_eq!(json["tables_discovered"], 0); + assert_eq!(json["discovery_status"], "failed"); + assert!(json["discovery_error"].is_string()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_create_connection_duplicate_name_rejected() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + // Create first connection (will fail discovery but register successfully) + let first_response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "dup_conn", + "source_type": "postgres", + "config": { + "host": "localhost", + "port": 5432, + "user": "user1", + "password": "pass1", + "database": "db1" + } + }))?))?, + ) + .await?; + + assert_eq!(first_response.status(), StatusCode::CREATED); + + // Try to create another connection with the same name + let second_response = app + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "dup_conn", + "source_type": "postgres", + "config": { + "host": "localhost", + "port": 5432, + "user": "user2", + "password": "pass2", + "database": "db2" + } + }))?))?, + ) + .await?; + + // Should be rejected as conflict + assert_eq!(second_response.status(), StatusCode::CONFLICT); + + let body = axum::body::to_bytes(second_response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert!(json["error"]["message"] + .as_str() + .unwrap() + .contains("already exists")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_create_connection_successful_discovery() -> Result<()> { + let (app, tempdir) = setup_test().await?; + + // Create a DuckDB file with a table + let db_path = tempdir.path().join("test.duckdb"); + { + let conn = duckdb::Connection::open(&db_path)?; + conn.execute_batch( + "CREATE TABLE users (id INTEGER, name VARCHAR); + INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');", + )?; + } + + // Create connection via API + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "test_duck", + "source_type": "duckdb", + "config": { + "path": db_path.to_str().unwrap() + } + }))?))?, + ) + .await?; + + // Should return 201 CREATED + assert_eq!(response.status(), StatusCode::CREATED); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + // Verify successful discovery response + assert_eq!(json["name"], "test_duck"); + assert_eq!(json["source_type"], "duckdb"); + assert_eq!(json["tables_discovered"], 1); + assert_eq!(json["discovery_status"], "success"); + // discovery_error should not be present on success + assert!(json["discovery_error"].is_null()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_discover_connection_successful() -> Result<()> { + let (app, tempdir) = setup_test().await?; + + // Create a DuckDB file with a table + let db_path = tempdir.path().join("discover_test.duckdb"); + { + let conn = duckdb::Connection::open(&db_path)?; + conn.execute_batch("CREATE TABLE orders (id INTEGER, amount DECIMAL);")?; + } + + // First create the connection + let create_response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_CONNECTIONS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "discover_duck", + "source_type": "duckdb", + "config": { + "path": db_path.to_str().unwrap() + } + }))?))?, + ) + .await?; + + assert_eq!(create_response.status(), StatusCode::CREATED); + + // Now call discover endpoint (even though already discovered, it should work) + let discover_path = PATH_CONNECTION_DISCOVER.replace("{name}", "discover_duck"); + + let discover_response = app + .oneshot( + Request::builder() + .method("POST") + .uri(&discover_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(discover_response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(discover_response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + // Verify successful discover response + assert_eq!(json["name"], "discover_duck"); + assert_eq!(json["tables_discovered"], 1); + assert_eq!(json["discovery_status"], "success"); + assert!(json["discovery_error"].is_null()); + + Ok(()) +}