diff --git a/Cargo.lock b/Cargo.lock index f863fc9..dfc7ed2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,42 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm-siv" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae0784134ba9375416d469ec31e7c5f9fa94405049cf08c5ce5b4698be673e0d" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "polyval", + "subtle", + "zeroize", +] + [[package]] name = "ahash" version = "0.7.8" @@ -793,6 +829,16 @@ dependencies = [ "phf", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.51" @@ -1006,6 +1052,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1030,6 +1077,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.21.3" @@ -2588,6 +2644,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "integer-encoding" version = "3.0.4" @@ -3110,6 +3175,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl-probe" version = "0.1.6" @@ -3365,6 +3436,18 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.11.1" @@ -3797,12 +3880,14 @@ dependencies = [ name = "rivetdb" version = "0.0.0" dependencies = [ + "aes-gcm-siv", "anyhow", "arrow-csv", "arrow-json", "arrow-schema", "async-trait", "axum", + "base64 0.22.1", 
"chrono", "clap", "config", @@ -3813,6 +3898,7 @@ dependencies = [ "log", "object_store", "proptest", + "rand 0.8.5", "rustyline", "serde", "serde_json", @@ -5232,6 +5318,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index c8710a8..60c09ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,11 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } thiserror = "2.0.17" uuid = { version = "1.11", features = ["v4"] } urlencoding = "2.1" +aes-gcm-siv = "0.11" +base64 = "0.22" [dev-dependencies] testcontainers = "0.23" testcontainers-modules = { version = "0.11", features = ["postgres"] } proptest = "1.6" +rand = "0.8" diff --git a/docs/plans/2025-12-15-secret-manager-design.md b/docs/plans/2025-12-15-secret-manager-design.md new file mode 100644 index 0000000..ea8b2e9 --- /dev/null +++ b/docs/plans/2025-12-15-secret-manager-design.md @@ -0,0 +1,257 @@ +# Secret Manager Design + +## Overview + +Rivet stores database credentials as managed secrets rather than embedding plaintext in connection configs. +A `SecretManager` orchestrates validation, catalog metadata, and encryption, while pluggable `SecretBackend` +implementations handle the actual byte storage. + +The backend implemented initially, encrypts values into the catalog database with AES-256-GCM-SIV. +Future backends (Vault, KMS, etc.) only need to implement the backend trait, allowing us to reuse the same manager logic and HTTP API. + +Connections reference secrets by name (e.g., `{ "credential": { "type": "secret_ref", "name": "prod-pg" } }`). +Secret values never get persisted outside the dedicated storage layer and are resolved just-in-time when creating external connections. + +## Architecture + +``` +┌────────────────────────────────────────────────────────┐ +│ SecretManager │ +│ - Validates names │ +│ - Manages catalog metadata + lifecycle status │ +│ - Coordinates optimistic locking + retries │ +│ - Delegates byte storage to SecretBackend │ +└──────────────┬─────────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────┐ +│ SecretBackend (trait) │ +│ get/put/delete raw bytes using SecretRecord context │ +│ returns provider_ref + status info │ +└──────────────┬─────────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────┐ +│ EncryptedCatalogBackend (v1 implementation) │ +│ - AES-256-GCM-SIV (nonce-misuse resistant) │ +│ - Key sourced from RIVETDB_SECRET_KEY (base64) │ +│ - Stores ciphertext in encrypted_secret_values table │ +└────────────────────────────────────────────────────────┘ +``` + +### Data Flow + +1. Client calls `POST /secrets` with `name` + plaintext value (UTF-8 string for v1). +2. `SecretManager::create` validates the name, claims the metadata row (`status=creating`), and writes bytes through the backend. +3. Metadata transitions to `active` with timestamps and provider info. +4. Connection configs refer to the secret by name. +5. 
+
+## Secret Name Constraints
+
+- Allowed characters: `[a-zA-Z0-9_-]`
+- Length: 1–128 characters
+- Comparison: case-insensitive (names normalized to lowercase)
+- Validation: centralized in `validate_and_normalize_name` before any catalog or backend call
+
+Validation regex: `^[a-zA-Z0-9_-]{1,128}$`
+
+## Database Schema
+
+```sql
+CREATE TABLE secrets (
+    name         TEXT PRIMARY KEY,  -- normalized lowercase
+    provider     TEXT NOT NULL,     -- e.g. "encrypted"
+    provider_ref TEXT,
+    status       TEXT NOT NULL,     -- 'creating' | 'active' | 'pending_delete'
+    created_at   TEXT NOT NULL,
+    updated_at   TEXT NOT NULL
+);
+
+CREATE TABLE encrypted_secret_values (
+    name            TEXT PRIMARY KEY,
+    encrypted_value BLOB NOT NULL
+);
+```
+
+Metadata rows track lifecycle and provider info. The encrypted values table is deliberately separate so other backends can ignore it completely.
+
+## Lifecycle & Concurrency
+
+```
+creating --(backend write succeeds)--> active --(delete requested)--> pending_delete --(cleanup)--> deleted
+```
+
+- **Creating**: Metadata row inserted to “claim” the name. Only one request can create a row because of the PK constraint.
+- **Active**: Secret is usable and visible to read/list operations.
+- **PendingDelete**: Delete in progress. Reads ignore this state, but delete calls continue to operate on it so retries work.
+
+### Create
+
+`SecretManager::create` uses optimistic locking:
+1. Insert the metadata row with `status=creating` (fails fast if the name is already active or creating).
+2. Write via the backend. On failure we leave the metadata row as `creating`; the user must delete to retry.
+3. Promote to `active` with an `OptimisticLock` keyed by `created_at`. If another process deleted the row before promotion, we fail with a database error, leaving at most an orphaned backend value that the next create overwrites.
+
+### Update
+
+`update` loads active metadata, writes via the backend, and updates timestamps/provider_ref with last-write-wins semantics. Since the secret already exists, we don’t need an optimistic lock; the final update simply reflects the latest successful request.
+
+### Delete
+
+Delete is a three-phase commit:
+1. Set `status=pending_delete` (idempotent) so reads stop seeing the secret immediately.
+2. Delete through the backend.
+3. Remove the metadata. If backend deletion fails, we keep the row as `pending_delete` so the client can retry. If metadata deletion fails, we log a warning; the value is already gone and a future create cleans up the stale row.
+
+## SecretBackend Trait
+
+```rust
+#[async_trait]
+pub trait SecretBackend: Debug + Send + Sync {
+    async fn get(&self, record: &SecretRecord) -> Result<Option<BackendRead>, BackendError>;
+    async fn put(&self, record: &SecretRecord, value: &[u8]) -> Result<BackendWrite, BackendError>;
+    async fn delete(&self, record: &SecretRecord) -> Result<bool, BackendError>;
+}
+
+pub struct SecretRecord {
+    pub name: String,
+    pub provider_ref: Option<String>,
+}
+```
+
+The manager passes the normalized name and any provider-specific reference (e.g., Vault path, KMS ARN). Backends return `BackendWrite` so the manager can persist refreshed `provider_ref` values and track whether the operation created or updated data.
+
+## Encrypted Catalog Backend
+
+The default backend encrypts plaintext values with AES-256-GCM-SIV, using the normalized name as associated data (AAD).
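+A rough sketch of that encrypt step with the `aes-gcm-siv` crate (illustrative only; `encrypt_value` is a hypothetical name, not the actual helper in `encrypted_catalog_backend.rs`):
+
+```rust
+use aes_gcm_siv::aead::{Aead, Error as AeadError, KeyInit, Payload};
+use aes_gcm_siv::{Aes256GcmSiv, Nonce};
+use rand::RngCore;
+
+fn encrypt_value(key: &[u8; 32], name: &str, plaintext: &[u8]) -> Result<Vec<u8>, AeadError> {
+    let cipher = Aes256GcmSiv::new_from_slice(key.as_slice()).expect("key is 32 bytes");
+    let mut nonce_bytes = [0u8; 12];
+    rand::thread_rng().fill_bytes(&mut nonce_bytes); // fresh random nonce per write
+    let nonce = Nonce::from_slice(&nonce_bytes);
+    // Binding the name as AAD means a ciphertext copied onto another
+    // secret's row fails authentication at decrypt time.
+    let ciphertext = cipher.encrypt(
+        nonce,
+        Payload { msg: plaintext, aad: name.as_bytes() },
+    )?;
+    // Assemble the blob in the layout described below.
+    let mut blob = Vec::with_capacity(6 + 12 + ciphertext.len());
+    blob.extend_from_slice(b"RVS1");      // magic
+    blob.push(0x01);                      // scheme: AES-256-GCM-SIV
+    blob.push(0x01);                      // key version
+    blob.extend_from_slice(&nonce_bytes); // 12-byte nonce
+    blob.extend_from_slice(&ciphertext);  // ciphertext || tag
+    Ok(blob)
+}
+```
+
+Decryption reverses this: validate the header, split off the nonce, and pass the same normalized name as AAD so a swapped ciphertext fails to authenticate.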
+The encrypted blob layout:
+
+```
+[ 'R','V','S','1' ][ scheme ][ key_version ][ 12-byte nonce ][ ciphertext || tag ]
+```
+
+- Scheme `0x01` = AES-256-GCM-SIV
+- Key version `0x01` for the current master key
+- Nonce is randomly generated per secret write
+
+The master key is provided via `RIVETDB_SECRET_KEY` (base64-encoded 32-byte key). The builder fails immediately if a configured key is invalid; if no key is configured at all, the engine falls back to a publicly known development-only default key and logs loud security warnings. Additional key versions can be supported later without schema changes.
+
+## Source & Credential Representation
+
+```rust
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Credential {
+    #[default]
+    None,
+    SecretRef { name: String },
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Source {
+    Postgres {
+        host: String,
+        port: u16,
+        user: String,
+        database: String,
+        #[serde(default)]
+        credential: Credential,
+    },
+    Snowflake {
+        account: String,
+        user: String,
+        warehouse: String,
+        database: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        role: Option<String>,
+        #[serde(default)]
+        credential: Credential,
+    },
+    Motherduck {
+        database: String,
+        #[serde(default)]
+        credential: Credential,
+    },
+    Duckdb { path: String },
+}
+```
+
+`Credential::resolve(&SecretManager)` either errors (`Credential::None`) or fetches the referenced secret, returning a UTF-8 `String`. Data fetchers call this right before building external connection strings, so invalid or missing secrets don’t block connection registration; discovery simply fails with a descriptive error until the secret exists.
+
+## HTTP API
+
+| Method | Path               | Description                                   |
+|--------|--------------------|-----------------------------------------------|
+| POST   | `/secrets`         | Create a new secret                           |
+| PUT    | `/secrets/{name}`  | Update an existing secret’s value             |
+| GET    | `/secrets`         | List metadata for all active secrets          |
+| GET    | `/secrets/{name}`  | Fetch metadata for a specific secret          |
+| DELETE | `/secrets/{name}`  | Delete a secret (three-phase commit)          |
+
+### Request/Response Models
+
+```rust
+pub struct CreateSecretRequest {
+    pub name: String,
+    pub value: String, // UTF-8 for v1
+}
+
+pub struct CreateSecretResponse {
+    pub name: String,
+    pub created_at: DateTime<Utc>,
+}
+
+pub struct UpdateSecretRequest {
+    pub value: String,
+}
+
+pub struct UpdateSecretResponse {
+    pub name: String,
+    pub updated_at: DateTime<Utc>,
+}
+
+pub struct SecretMetadataResponse {
+    pub name: String,
+    pub created_at: DateTime<Utc>,
+    pub updated_at: DateTime<Utc>,
+}
+```
+
+The API accepts UTF-8 strings today; internally the manager stores raw bytes so binary secrets can be added later without schema changes.
+
+### Error Handling
+
+| Scenario                | Status | Message                                                   |
+|-------------------------|--------|-----------------------------------------------------------|
+| Secret manager disabled | 503    | `Secret manager not configured`                           |
+| Invalid name            | 400    | Includes validation hint                                  |
+| Secret not found        | 404    | `Secret '{name}' not found`                               |
+| Secret already exists   | 409    | `Secret '{name}' already exists`                          |
+| Creation in progress    | 409    | `Secret '{name}' is being created by another process...`  |
+| Backend/storage failure | 500    | `Backend error: ...`                                      |
+
+Errors are surfaced via `SecretError -> ApiError`, so any new failure modes automatically produce consistent HTTP responses.
+
+## Configuration
+
+- `RIVETDB_SECRET_KEY`: base64-encoded 32-byte master key for the encrypted backend; if unset, an insecure default key is used (development only, with loud warnings).
+- `RivetEngine::builder().secret_key(...)` overrides the env var. Precedence is: explicit builder key, then `RIVETDB_SECRET_KEY`, then the insecure default key. The secret manager is therefore always initialized, so every HTTP secret operation has a manager to talk to.
+
+## Code Organization
+
+- `src/secrets/mod.rs`: `SecretManager`, lifecycle management, validation, optimistic locking.
+- `src/secrets/backend.rs`: `SecretBackend` trait plus helper structs (`SecretRecord`, `BackendWrite`, etc.).
+- `src/secrets/encrypted_catalog_backend.rs`: AES-256-GCM-SIV backend implementation.
+- `src/catalog/*_manager.rs`: catalog metadata operations, status transitions, and encrypted value storage helpers.
+- `src/http/{handlers,models}.rs`: Secret CRUD endpoints and DTOs.
+- `src/source.rs` & `src/datafetch`: credential references and runtime secret resolution.
+
+## Test Coverage
+
+- Unit tests for encryption/decryption round-trips, base64 key parsing, and decryption failure scenarios.
+- Secret manager tests for create/update/delete/list flows, optimistic locking, name normalization, and invalid names.
+- HTTP integration tests for all secret endpoints (including update and error cases).
+- Engine HTTP tests that cover connection creation without upfront credentials and discovery retries.
+
+This final design matches the implementation on `feat/secret-manager`: the manager owns lifecycle + metadata, backends only store ciphertext, and the API exposes clear semantics for create/update/delete with robust concurrency guarantees.
diff --git a/src/catalog/manager.rs b/src/catalog/manager.rs
index f2cde3e..6175d19 100644
--- a/src/catalog/manager.rs
+++ b/src/catalog/manager.rs
@@ -1,9 +1,23 @@
+use crate::secrets::{SecretMetadata, SecretStatus};
 use anyhow::Result;
 use async_trait::async_trait;
+use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
 use sqlx::FromRow;
 use std::fmt::Debug;

+/// Used to conditionally update a secret only if it hasn't been modified.
+#[derive(Debug, Clone, Copy)]
+pub struct OptimisticLock {
+    pub created_at: DateTime<Utc>,
+}
+
+impl From<DateTime<Utc>> for OptimisticLock {
+    fn from(created_at: DateTime<Utc>) -> Self {
+        Self { created_at }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
 pub struct ConnectionInfo {
     pub id: i32,
@@ -68,4 +82,45 @@ pub trait CatalogManager: Debug + Send + Sync {

     /// Delete connection and all associated table rows from metadata.
     async fn delete_connection(&self, name: &str) -> Result<()>;
+
+    // Secret management methods - metadata (used by all secret providers)
+
+    /// Get metadata for an active secret (without value).
+    /// Returns None for secrets with status != 'active'.
+    async fn get_secret_metadata(&self, name: &str) -> Result<Option<SecretMetadata>>;
+
+    /// Get metadata for a secret regardless of status (for internal cleanup).
+    async fn get_secret_metadata_any_status(&self, name: &str) -> Result<Option<SecretMetadata>>;
+
+    /// Create secret metadata. Fails if the secret already exists.
+    async fn create_secret_metadata(&self, metadata: &SecretMetadata) -> Result<()>;
+
+    /// Update existing secret metadata.
+    /// If `lock` is Some, only updates if created_at matches (returns false on mismatch).
+    /// If `lock` is None, updates unconditionally.
+    async fn update_secret_metadata(
+        &self,
+        metadata: &SecretMetadata,
+        lock: Option<OptimisticLock>,
+    ) -> Result<bool>;
+
+    /// Set the status of a secret.
+    async fn set_secret_status(&self, name: &str, status: SecretStatus) -> Result<bool>;
+
+    /// Delete secret metadata. Returns true if the secret existed.
+    async fn delete_secret_metadata(&self, name: &str) -> Result<bool>;
+
+    /// List all active secrets (metadata only).
+ async fn list_secrets(&self) -> Result>; + + // Secret management methods - encrypted storage (used by EncryptedSecretManager only) + + /// Get the encrypted value for a secret. + async fn get_encrypted_secret(&self, name: &str) -> Result>>; + + /// Store or update an encrypted secret value. + async fn put_encrypted_secret_value(&self, name: &str, encrypted_value: &[u8]) -> Result<()>; + + /// Delete an encrypted secret value. Returns true if it existed. + async fn delete_encrypted_secret_value(&self, name: &str) -> Result; } diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 8c8f193..b39b688 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -5,6 +5,6 @@ mod sqlite_manager; mod manager; -pub use manager::{CatalogManager, ConnectionInfo, TableInfo}; +pub use manager::{CatalogManager, ConnectionInfo, OptimisticLock, TableInfo}; pub use postgres_manager::PostgresCatalogManager; pub use sqlite_manager::SqliteCatalogManager; diff --git a/src/catalog/postgres_manager.rs b/src/catalog/postgres_manager.rs index 1177ad7..8a62c33 100644 --- a/src/catalog/postgres_manager.rs +++ b/src/catalog/postgres_manager.rs @@ -1,16 +1,43 @@ use crate::catalog::backend::CatalogBackend; -use crate::catalog::manager::{CatalogManager, ConnectionInfo, TableInfo}; +use crate::catalog::manager::{CatalogManager, ConnectionInfo, OptimisticLock, TableInfo}; use crate::catalog::migrations::{run_migrations, CatalogMigrations}; +use crate::secrets::{SecretMetadata, SecretStatus}; use anyhow::Result; use async_trait::async_trait; +use chrono::{DateTime, Utc}; use sqlx::postgres::PgPoolOptions; use sqlx::{PgPool, Postgres}; use std::fmt::{self, Debug, Formatter}; +use std::str::FromStr; pub struct PostgresCatalogManager { backend: CatalogBackend, } +/// Row type for secret metadata queries (Postgres handles timestamps natively) +#[derive(sqlx::FromRow)] +struct SecretMetadataRow { + name: String, + provider: String, + provider_ref: Option, + status: String, + created_at: DateTime, + updated_at: DateTime, +} + +impl SecretMetadataRow { + fn into_metadata(self) -> SecretMetadata { + SecretMetadata { + name: self.name, + provider: self.provider, + provider_ref: self.provider_ref, + status: SecretStatus::from_str(&self.status).unwrap_or(SecretStatus::Active), + created_at: self.created_at, + updated_at: self.updated_at, + } + } +} + impl PostgresCatalogManager { pub async fn new(connection_string: &str) -> Result { let pool = PgPoolOptions::new() @@ -52,6 +79,28 @@ impl PostgresCatalogManager { .execute(pool) .await?; + sqlx::query( + "CREATE TABLE IF NOT EXISTS secrets ( + name TEXT PRIMARY KEY, + provider TEXT NOT NULL, + provider_ref TEXT, + status TEXT NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL + )", + ) + .execute(pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS encrypted_secret_values ( + name TEXT PRIMARY KEY, + encrypted_value BYTEA NOT NULL + )", + ) + .execute(pool) + .await?; + Ok(()) } } @@ -167,6 +216,139 @@ impl CatalogManager for PostgresCatalogManager { async fn delete_connection(&self, name: &str) -> Result<()> { self.backend.delete_connection(name).await } + + async fn get_secret_metadata(&self, name: &str) -> Result> { + let row: Option = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE name = $1 AND status = 'active'", + ) + .bind(name) + .fetch_optional(self.backend.pool()) + .await?; + + Ok(row.map(SecretMetadataRow::into_metadata)) + } + + async fn 
get_secret_metadata_any_status(&self, name: &str) -> Result> { + let row: Option = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE name = $1", + ) + .bind(name) + .fetch_optional(self.backend.pool()) + .await?; + + Ok(row.map(SecretMetadataRow::into_metadata)) + } + + async fn create_secret_metadata(&self, metadata: &SecretMetadata) -> Result<()> { + sqlx::query( + "INSERT INTO secrets (name, provider, provider_ref, status, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $5, $6)", + ) + .bind(&metadata.name) + .bind(&metadata.provider) + .bind(&metadata.provider_ref) + .bind(metadata.status.as_str()) + .bind(metadata.created_at) + .bind(metadata.updated_at) + .execute(self.backend.pool()) + .await?; + + Ok(()) + } + + async fn update_secret_metadata( + &self, + metadata: &SecretMetadata, + lock: Option, + ) -> Result { + use sqlx::QueryBuilder; + + let mut qb = QueryBuilder::new("UPDATE secrets SET "); + qb.push("provider = ") + .push_bind(&metadata.provider) + .push(", provider_ref = ") + .push_bind(&metadata.provider_ref) + .push(", status = ") + .push_bind(metadata.status.as_str()) + .push(", updated_at = ") + .push_bind(metadata.updated_at) + .push(" WHERE name = ") + .push_bind(&metadata.name); + + if let Some(lock) = lock { + qb.push(" AND created_at = ").push_bind(lock.created_at); + } + + let result = qb.build().execute(self.backend.pool()).await?; + Ok(result.rows_affected() > 0) + } + + async fn set_secret_status(&self, name: &str, status: SecretStatus) -> Result { + let result = sqlx::query("UPDATE secrets SET status = $1 WHERE name = $2") + .bind(status.as_str()) + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn delete_secret_metadata(&self, name: &str) -> Result { + let result = sqlx::query("DELETE FROM secrets WHERE name = $1") + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn get_encrypted_secret(&self, name: &str) -> Result>> { + sqlx::query_scalar("SELECT encrypted_value FROM encrypted_secret_values WHERE name = $1") + .bind(name) + .fetch_optional(self.backend.pool()) + .await + .map_err(Into::into) + } + + async fn put_encrypted_secret_value(&self, name: &str, encrypted_value: &[u8]) -> Result<()> { + sqlx::query( + "INSERT INTO encrypted_secret_values (name, encrypted_value) \ + VALUES ($1, $2) \ + ON CONFLICT (name) DO UPDATE SET \ + encrypted_value = excluded.encrypted_value", + ) + .bind(name) + .bind(encrypted_value) + .execute(self.backend.pool()) + .await?; + + Ok(()) + } + + async fn delete_encrypted_secret_value(&self, name: &str) -> Result { + let result = sqlx::query("DELETE FROM encrypted_secret_values WHERE name = $1") + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn list_secrets(&self) -> Result> { + let rows: Vec = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE status = 'active' ORDER BY name", + ) + .fetch_all(self.backend.pool()) + .await?; + + Ok(rows + .into_iter() + .map(SecretMetadataRow::into_metadata) + .collect()) + } } impl Debug for PostgresCatalogManager { diff --git a/src/catalog/sqlite_manager.rs b/src/catalog/sqlite_manager.rs index 43ceef7..7301b0e 100644 --- a/src/catalog/sqlite_manager.rs +++ b/src/catalog/sqlite_manager.rs @@ -1,16 +1,43 @@ use crate::catalog::backend::CatalogBackend; -use crate::catalog::manager::{CatalogManager, 
ConnectionInfo, TableInfo}; +use crate::catalog::manager::{CatalogManager, ConnectionInfo, OptimisticLock, TableInfo}; use crate::catalog::migrations::{run_migrations, CatalogMigrations}; +use crate::secrets::{SecretMetadata, SecretStatus}; use anyhow::Result; use async_trait::async_trait; +use chrono::Utc; use sqlx::{Sqlite, SqlitePool}; use std::fmt::{self, Debug, Formatter}; +use std::str::FromStr; pub struct SqliteCatalogManager { backend: CatalogBackend, catalog_path: String, } +/// Row type for secret metadata queries (SQLite stores timestamps as strings) +#[derive(sqlx::FromRow)] +struct SecretMetadataRow { + name: String, + provider: String, + provider_ref: Option, + status: String, + created_at: String, + updated_at: String, +} + +impl SecretMetadataRow { + fn into_metadata(self) -> SecretMetadata { + SecretMetadata { + name: self.name, + provider: self.provider, + provider_ref: self.provider_ref, + status: SecretStatus::from_str(&self.status).unwrap_or(SecretStatus::Active), + created_at: self.created_at.parse().unwrap_or_else(|_| Utc::now()), + updated_at: self.updated_at.parse().unwrap_or_else(|_| Utc::now()), + } + } +} + impl Debug for SqliteCatalogManager { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("SqliteCatalogManager") @@ -67,6 +94,32 @@ impl SqliteCatalogManager { .execute(pool) .await?; + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS secrets ( + name TEXT PRIMARY KEY, + provider TEXT NOT NULL, + provider_ref TEXT, + status TEXT NOT NULL DEFAULT 'active', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + "#, + ) + .execute(pool) + .await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS encrypted_secret_values ( + name TEXT PRIMARY KEY, + encrypted_value BLOB NOT NULL + ) + "#, + ) + .execute(pool) + .await?; + Ok(()) } } @@ -145,6 +198,145 @@ impl CatalogManager for SqliteCatalogManager { async fn delete_connection(&self, name: &str) -> Result<()> { self.backend.delete_connection(name).await } + + async fn get_secret_metadata(&self, name: &str) -> Result> { + let row: Option = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE name = ? 
AND status = 'active'", + ) + .bind(name) + .fetch_optional(self.backend.pool()) + .await?; + + Ok(row.map(SecretMetadataRow::into_metadata)) + } + + async fn get_secret_metadata_any_status(&self, name: &str) -> Result> { + let row: Option = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE name = ?", + ) + .bind(name) + .fetch_optional(self.backend.pool()) + .await?; + + Ok(row.map(SecretMetadataRow::into_metadata)) + } + + async fn create_secret_metadata(&self, metadata: &SecretMetadata) -> Result<()> { + let created_at = metadata.created_at.to_rfc3339(); + let updated_at = metadata.updated_at.to_rfc3339(); + + sqlx::query( + "INSERT INTO secrets (name, provider, provider_ref, status, created_at, updated_at) \ + VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&metadata.name) + .bind(&metadata.provider) + .bind(&metadata.provider_ref) + .bind(metadata.status.as_str()) + .bind(&created_at) + .bind(&updated_at) + .execute(self.backend.pool()) + .await?; + + Ok(()) + } + + async fn update_secret_metadata( + &self, + metadata: &SecretMetadata, + lock: Option, + ) -> Result { + use sqlx::QueryBuilder; + + let updated_at = metadata.updated_at.to_rfc3339(); + + let mut qb = QueryBuilder::new("UPDATE secrets SET "); + qb.push("provider = ") + .push_bind(&metadata.provider) + .push(", provider_ref = ") + .push_bind(&metadata.provider_ref) + .push(", status = ") + .push_bind(metadata.status.as_str()) + .push(", updated_at = ") + .push_bind(&updated_at) + .push(" WHERE name = ") + .push_bind(&metadata.name); + + if let Some(lock) = lock { + let expected_created_at = lock.created_at.to_rfc3339(); + qb.push(" AND created_at = ").push_bind(expected_created_at); + } + + let result = qb.build().execute(self.backend.pool()).await?; + Ok(result.rows_affected() > 0) + } + + async fn set_secret_status(&self, name: &str, status: SecretStatus) -> Result { + let result = sqlx::query("UPDATE secrets SET status = ? WHERE name = ?") + .bind(status.as_str()) + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn delete_secret_metadata(&self, name: &str) -> Result { + let result = sqlx::query("DELETE FROM secrets WHERE name = ?") + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn get_encrypted_secret(&self, name: &str) -> Result>> { + sqlx::query_scalar("SELECT encrypted_value FROM encrypted_secret_values WHERE name = ?") + .bind(name) + .fetch_optional(self.backend.pool()) + .await + .map_err(Into::into) + } + + async fn put_encrypted_secret_value(&self, name: &str, encrypted_value: &[u8]) -> Result<()> { + sqlx::query( + "INSERT INTO encrypted_secret_values (name, encrypted_value) \ + VALUES (?, ?) 
\ + ON CONFLICT (name) DO UPDATE SET \ + encrypted_value = excluded.encrypted_value", + ) + .bind(name) + .bind(encrypted_value) + .execute(self.backend.pool()) + .await?; + + Ok(()) + } + + async fn delete_encrypted_secret_value(&self, name: &str) -> Result { + let result = sqlx::query("DELETE FROM encrypted_secret_values WHERE name = ?") + .bind(name) + .execute(self.backend.pool()) + .await?; + + Ok(result.rows_affected() > 0) + } + + async fn list_secrets(&self) -> Result> { + let rows: Vec = sqlx::query_as( + "SELECT name, provider, provider_ref, status, created_at, updated_at \ + FROM secrets WHERE status = 'active' ORDER BY name", + ) + .fetch_all(self.backend.pool()) + .await?; + + Ok(rows + .into_iter() + .map(SecretMetadataRow::into_metadata) + .collect()) + } } impl CatalogMigrations for SqliteMigrationBackend { diff --git a/src/config/mod.rs b/src/config/mod.rs index aa49c79..9c420c5 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -8,6 +8,8 @@ pub struct AppConfig { pub storage: StorageConfig, #[serde(default)] pub paths: PathsConfig, + #[serde(default)] + pub secrets: SecretsConfig, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -55,6 +57,13 @@ pub struct PathsConfig { pub cache_dir: Option, } +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct SecretsConfig { + /// Encryption key for secrets (base64-encoded 32-byte key). + /// Can also be set via RIVETDB_SECRET_KEY environment variable. + pub encryption_key: Option, +} + impl AppConfig { /// Load configuration from file and environment variables pub fn load(config_path: &str) -> Result { diff --git a/src/datafetch/fetcher.rs b/src/datafetch/fetcher.rs index 576fa1a..d21b4ef 100644 --- a/src/datafetch/fetcher.rs +++ b/src/datafetch/fetcher.rs @@ -2,13 +2,18 @@ use async_trait::async_trait; use super::native::StreamingParquetWriter; use super::{DataFetchError, TableMetadata}; +use crate::secrets::SecretManager; use crate::source::Source; /// Trait for fetching data from remote sources #[async_trait] pub trait DataFetcher: Send + Sync + std::fmt::Debug { /// Discover all tables (with columns) from the remote source - async fn discover_tables(&self, source: &Source) -> Result, DataFetchError>; + async fn discover_tables( + &self, + source: &Source, + secrets: &SecretManager, + ) -> Result, DataFetchError>; /// Fetch table data and write to the provided Parquet writer. /// The writer is pre-initialized with the destination path. 
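As an aside on the `SecretsConfig.encryption_key` / `RIVETDB_SECRET_KEY` value introduced above: besides `openssl rand -base64 32`, a valid key can be generated with the same `rand` and `base64` crates this change adds. A standalone sketch (not code from this diff), mirroring the `test_secret_key` helper in the engine tests:

```rust
use base64::{engine::general_purpose::STANDARD, Engine};
use rand::RngCore;

// Prints a base64-encoded 32-byte key suitable for RIVETDB_SECRET_KEY.
fn main() {
    let mut key = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut key);
    println!("{}", STANDARD.encode(key));
}
```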
@@ -16,6 +21,7 @@ pub trait DataFetcher: Send + Sync + std::fmt::Debug { async fn fetch_table( &self, source: &Source, + secrets: &SecretManager, catalog: Option<&str>, schema: &str, table: &str, diff --git a/src/datafetch/native/duckdb.rs b/src/datafetch/native/duckdb.rs index 37695ac..1fab54d 100644 --- a/src/datafetch/native/duckdb.rs +++ b/src/datafetch/native/duckdb.rs @@ -2,15 +2,20 @@ use duckdb::Connection; use std::collections::HashMap; +use urlencoding::encode; use crate::datafetch::{ColumnMetadata, DataFetchError, TableMetadata}; +use crate::secrets::SecretManager; use crate::source::Source; use super::StreamingParquetWriter; /// Discover tables and columns from DuckDB/MotherDuck -pub async fn discover_tables(source: &Source) -> Result, DataFetchError> { - let connection_string = source.connection_string(); +pub async fn discover_tables( + source: &Source, + secrets: &SecretManager, +) -> Result, DataFetchError> { + let connection_string = resolve_connection_string(source, secrets).await?; let catalog = source.catalog().map(|s| s.to_string()); tokio::task::spawn_blocking(move || { @@ -20,6 +25,33 @@ pub async fn discover_tables(source: &Source) -> Result, Data .map_err(|e| DataFetchError::Connection(e.to_string()))? } +/// Resolve credentials and build connection string for DuckDB or Motherduck source. +pub async fn resolve_connection_string( + source: &Source, + secrets: &SecretManager, +) -> Result { + match source { + Source::Duckdb { path } => Ok(path.clone()), + Source::Motherduck { + database, + credential, + } => { + let token = credential + .resolve(secrets) + .await + .map_err(|e| DataFetchError::Connection(e.to_string()))?; + Ok(format!( + "md:{}?motherduck_token={}", + encode(database), + encode(&token) + )) + } + _ => Err(DataFetchError::Connection( + "Expected DuckDB or Motherduck source".to_string(), + )), + } +} + fn discover_tables_sync( connection_string: &str, catalog: Option<&str>, @@ -107,6 +139,7 @@ enum FetchMessage { /// Fetch table data and write to Parquet using streaming to avoid OOM on large tables pub async fn fetch_table( source: &Source, + secrets: &SecretManager, _catalog: Option<&str>, schema: &str, table: &str, @@ -115,7 +148,7 @@ pub async fn fetch_table( use datafusion::arrow::record_batch::RecordBatch; use std::sync::Arc; - let connection_string = source.connection_string(); + let connection_string = resolve_connection_string(source, secrets).await?; let schema = schema.to_string(); let table = table.to_string(); diff --git a/src/datafetch/native/mod.rs b/src/datafetch/native/mod.rs index eaccadd..ea0a9df 100644 --- a/src/datafetch/native/mod.rs +++ b/src/datafetch/native/mod.rs @@ -7,6 +7,7 @@ pub use parquet_writer::StreamingParquetWriter; use async_trait::async_trait; use crate::datafetch::{DataFetchError, DataFetcher, TableMetadata}; +use crate::secrets::SecretManager; use crate::source::Source; /// Native Rust driver-based data fetcher @@ -21,12 +22,16 @@ impl NativeFetcher { #[async_trait] impl DataFetcher for NativeFetcher { - async fn discover_tables(&self, source: &Source) -> Result, DataFetchError> { + async fn discover_tables( + &self, + source: &Source, + secrets: &SecretManager, + ) -> Result, DataFetchError> { match source { Source::Duckdb { .. } | Source::Motherduck { .. } => { - duckdb::discover_tables(source).await + duckdb::discover_tables(source, secrets).await } - Source::Postgres { .. } => postgres::discover_tables(source).await, + Source::Postgres { .. 
} => postgres::discover_tables(source, secrets).await, Source::Snowflake { .. } => Err(DataFetchError::UnsupportedDriver("Snowflake")), } } @@ -34,6 +39,7 @@ impl DataFetcher for NativeFetcher { async fn fetch_table( &self, source: &Source, + secrets: &SecretManager, catalog: Option<&str>, schema: &str, table: &str, @@ -41,10 +47,10 @@ impl DataFetcher for NativeFetcher { ) -> Result<(), DataFetchError> { match source { Source::Duckdb { .. } | Source::Motherduck { .. } => { - duckdb::fetch_table(source, catalog, schema, table, writer).await + duckdb::fetch_table(source, secrets, catalog, schema, table, writer).await } Source::Postgres { .. } => { - postgres::fetch_table(source, catalog, schema, table, writer).await + postgres::fetch_table(source, secrets, catalog, schema, table, writer).await } Source::Snowflake { .. } => Err(DataFetchError::UnsupportedDriver("Snowflake")), } diff --git a/src/datafetch/native/postgres.rs b/src/datafetch/native/postgres.rs index d6a0db6..7783d8e 100644 --- a/src/datafetch/native/postgres.rs +++ b/src/datafetch/native/postgres.rs @@ -10,12 +10,62 @@ use futures::StreamExt; use sqlx::postgres::{PgColumn, PgConnection, PgRow}; use sqlx::{Column, Connection, Row, TypeInfo}; use std::sync::Arc; +use urlencoding::encode; use crate::datafetch::{ColumnMetadata, DataFetchError, TableMetadata}; +use crate::secrets::SecretManager; use crate::source::Source; use super::StreamingParquetWriter; +/// Build a PostgreSQL connection string from source configuration and resolved password. +fn build_connection_string( + host: &str, + port: u16, + user: &str, + database: &str, + password: &str, +) -> String { + format!( + "postgresql://{}:{}@{}:{}/{}", + encode(user), + encode(password), + encode(host), + port, + encode(database) + ) +} + +/// Resolve credentials and build connection string for a Postgres source. +pub async fn resolve_connection_string( + source: &Source, + secrets: &SecretManager, +) -> Result { + let (host, port, user, database, credential) = match source { + Source::Postgres { + host, + port, + user, + database, + credential, + } => (host, *port, user, database, credential), + _ => { + return Err(DataFetchError::Connection( + "Expected Postgres source".to_string(), + )) + } + }; + + let password = credential + .resolve(secrets) + .await + .map_err(|e| DataFetchError::Connection(e.to_string()))?; + + Ok(build_connection_string( + host, port, user, database, &password, + )) +} + /// Connect to PostgreSQL with automatic SSL retry. /// If the initial connection fails with an "insecure connection" error, /// automatically retries with `sslmode=require` appended to the connection string. 
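A quick check of what `build_connection_string` above produces for awkward credentials (a sketch with made-up values), showing why each component is run through `urlencoding::encode` while the numeric port is not:

```rust
use urlencoding::encode;

fn main() {
    // Reserved characters in the password must not break URL parsing.
    let url = format!(
        "postgresql://{}:{}@{}:{}/{}",
        encode("app_user"),
        encode("p@ss:w/rd"),
        encode("db.internal"),
        5432,
        encode("analytics")
    );
    assert_eq!(
        url,
        "postgresql://app_user:p%40ss%3Aw%2Frd@db.internal:5432/analytics"
    );
}
```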
@@ -42,8 +92,12 @@ async fn connect_with_ssl_retry(connection_string: &str) -> Result Result, DataFetchError> { - let mut conn = connect_with_ssl_retry(&source.connection_string()).await?; +pub async fn discover_tables( + source: &Source, + secrets: &SecretManager, +) -> Result, DataFetchError> { + let connection_string = resolve_connection_string(source, secrets).await?; + let mut conn = connect_with_ssl_retry(&connection_string).await?; let rows = sqlx::query( r#" @@ -110,12 +164,14 @@ pub async fn discover_tables(source: &Source) -> Result, Data /// Fetch table data and write to Parquet using streaming to avoid OOM on large tables pub async fn fetch_table( source: &Source, + secrets: &SecretManager, _catalog: Option<&str>, schema: &str, table: &str, writer: &mut StreamingParquetWriter, ) -> Result<(), DataFetchError> { - let mut conn = connect_with_ssl_retry(&source.connection_string()).await?; + let connection_string = resolve_connection_string(source, secrets).await?; + let mut conn = connect_with_ssl_retry(&connection_string).await?; // Build query - properly escape identifiers let query = format!( @@ -138,7 +194,7 @@ pub async fn fetch_table( None => { // Empty table: query information_schema for schema // Need a new connection since stream borrows conn - let mut schema_conn = connect_with_ssl_retry(&source.connection_string()).await?; + let mut schema_conn = connect_with_ssl_retry(&connection_string).await?; let schema_rows = sqlx::query( r#" SELECT column_name, data_type, is_nullable diff --git a/src/datafetch/orchestrator.rs b/src/datafetch/orchestrator.rs index b7123ea..93a5ecd 100644 --- a/src/datafetch/orchestrator.rs +++ b/src/datafetch/orchestrator.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use super::native::StreamingParquetWriter; use super::DataFetcher; use crate::catalog::CatalogManager; +use crate::secrets::SecretManager; use crate::source::Source; use crate::storage::StorageManager; @@ -13,6 +14,7 @@ pub struct FetchOrchestrator { fetcher: Arc, storage: Arc, catalog: Arc, + secret_manager: Arc, } impl FetchOrchestrator { @@ -20,11 +22,13 @@ impl FetchOrchestrator { fetcher: Arc, storage: Arc, catalog: Arc, + secret_manager: Arc, ) -> Self { Self { fetcher, storage, catalog, + secret_manager, } } @@ -50,6 +54,7 @@ impl FetchOrchestrator { self.fetcher .fetch_table( source, + &self.secret_manager, None, // catalog schema_name, table_name, diff --git a/src/engine.rs b/src/engine.rs index 68aafc1..e71192d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,6 +1,7 @@ use crate::catalog::{CatalogManager, ConnectionInfo, SqliteCatalogManager, TableInfo}; use crate::datafetch::{DataFetcher, FetchOrchestrator, NativeFetcher}; use crate::datafusion::{block_on, RivetCatalogProvider}; +use crate::secrets::{EncryptedCatalogBackend, SecretManager, ENCRYPTED_PROVIDER_TYPE}; use crate::source::Source; use crate::storage::{FilesystemStorage, StorageManager}; use anyhow::Result; @@ -13,6 +14,11 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::error; +/// Default insecure encryption key for development use only. +/// This key is publicly known and provides NO security. +/// It is a base64-encoded 32-byte key: "INSECURE_DEFAULT_KEY_RIVETDB!!!!" 
+const DEFAULT_INSECURE_KEY: &str = "SU5TRUNVUkVfREVGQVVMVF9LRVlfUklWRVREQiEhISE="; + pub struct QueryResponse { pub results: Vec, pub execution_time: Duration, @@ -24,6 +30,7 @@ pub struct RivetEngine { df_ctx: SessionContext, storage: Arc, orchestrator: Arc, + secret_manager: Arc, } impl RivetEngine { @@ -60,6 +67,11 @@ impl RivetEngine { builder = builder.cache_dir(PathBuf::from(cache)); } + // Set secret key if explicitly configured + if let Some(key) = &config.secrets.encryption_key { + builder = builder.secret_key(key); + } + // Only create explicit catalog for non-sqlite backends if config.catalog.catalog_type != "sqlite" { let catalog = Self::create_catalog_from_config(config).await?; @@ -189,6 +201,11 @@ impl RivetEngine { &self.storage } + /// Get a reference to the secret manager. + pub fn secret_manager(&self) -> &Arc { + &self.secret_manager + } + /// Register all connections from the catalog store as DataFusion catalogs. async fn register_existing_connections(&mut self) -> Result<()> { let connections = self.catalog.list_connections().await?; @@ -260,7 +277,7 @@ impl RivetEngine { info!("Discovering tables for {} source...", source_type); let fetcher = crate::datafetch::NativeFetcher::new(); let tables = fetcher - .discover_tables(&source) + .discover_tables(&source, &self.secret_manager) .await .map_err(|e| anyhow::anyhow!("Discovery failed: {}", e))?; @@ -504,6 +521,7 @@ pub struct RivetEngineBuilder { cache_dir: Option, catalog: Option>, storage: Option>, + secret_key: Option, } impl Default for RivetEngineBuilder { @@ -519,6 +537,7 @@ impl RivetEngineBuilder { cache_dir: None, catalog: None, storage: None, + secret_key: std::env::var("RIVETDB_SECRET_KEY").ok(), } } @@ -550,6 +569,14 @@ impl RivetEngineBuilder { self } + /// Set the encryption key for the secret manager (base64-encoded 32-byte key). + /// If not set, falls back to RIVETDB_SECRET_KEY environment variable. + /// If neither is set, uses a default insecure key (with loud warnings). + pub fn secret_key(mut self, key: impl Into) -> Self { + self.secret_key = Some(key.into()); + self + } + /// Resolve the base directory, using default if not set. fn resolve_base_dir(&self) -> PathBuf { self.base_dir.clone().unwrap_or_else(|| { @@ -609,12 +636,47 @@ impl RivetEngineBuilder { let df_ctx = SessionContext::new(); storage.register_with_datafusion(&df_ctx)?; - // Step 6: Create fetch orchestrator + // Step 6: Initialize secret manager + let (secret_key, using_default_key) = match self.secret_key { + Some(key) => (key, false), + None => { + // Use insecure default key for development convenience + (DEFAULT_INSECURE_KEY.to_string(), true) + } + }; + + if using_default_key { + warn!("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); + warn!("!!! SECURITY WARNING !!!"); + warn!("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); + warn!("!!! Using DEFAULT INSECURE encryption key for secrets. !!!"); + warn!("!!! This key is PUBLICLY KNOWN and provides NO SECURITY. !!!"); + warn!("!!! !!!"); + warn!("!!! DO NOT USE IN PRODUCTION! !!!"); + warn!("!!! !!!"); + warn!("!!! To fix: Set RIVETDB_SECRET_KEY environment variable !!!"); + warn!("!!! 
Generate a key with: openssl rand -base64 32 !!!"); + warn!("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); + } + + let backend = Arc::new( + EncryptedCatalogBackend::from_base64_key(&secret_key, catalog.clone()) + .map_err(|e| anyhow::anyhow!("Invalid secret key: {}", e))?, + ); + let secret_manager = Arc::new(SecretManager::new( + backend, + catalog.clone(), + ENCRYPTED_PROVIDER_TYPE, + )); + info!("Secret manager initialized"); + + // Step 7: Create fetch orchestrator (needs secret_manager) let fetcher = Arc::new(NativeFetcher::new()); let orchestrator = Arc::new(FetchOrchestrator::new( fetcher, storage.clone(), catalog.clone(), + secret_manager.clone(), )); let mut engine = RivetEngine { @@ -622,6 +684,7 @@ impl RivetEngineBuilder { df_ctx, storage, orchestrator, + secret_manager, }; // Register all existing connections as DataFusion catalogs @@ -636,6 +699,15 @@ mod tests { use super::*; use tempfile::TempDir; + /// Generate a test secret key (base64-encoded 32 bytes) + fn test_secret_key() -> String { + use base64::{engine::general_purpose::STANDARD, Engine}; + use rand::RngCore; + let mut key = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut key); + STANDARD.encode(key) + } + #[tokio::test(flavor = "multi_thread")] async fn test_builder_pattern() { let temp_dir = TempDir::new().unwrap(); @@ -658,6 +730,7 @@ mod tests { .base_dir(base_dir.clone()) .catalog(catalog) .storage(storage) + .secret_key(test_secret_key()) .build() .await; @@ -681,6 +754,7 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let result = RivetEngine::builder() .base_dir(temp_dir.path().to_path_buf()) + .secret_key(test_secret_key()) .build() .await; assert!( @@ -695,24 +769,41 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] - async fn test_defaults_constructor() { - // Test the defaults() convenience constructor + async fn test_builder_uses_default_key_without_secret_key() { + // Test that builder uses default insecure key when no secret key is provided let temp_dir = TempDir::new().unwrap(); - let engine = RivetEngine::defaults(temp_dir.path().to_path_buf()).await; + + // Temporarily clear the env var if set + let old_key = std::env::var("RIVETDB_SECRET_KEY").ok(); + std::env::remove_var("RIVETDB_SECRET_KEY"); + + let result = RivetEngine::builder() + .base_dir(temp_dir.path().to_path_buf()) + .build() + .await; + + // Restore env var if it was set + if let Some(key) = old_key { + std::env::set_var("RIVETDB_SECRET_KEY", key); + } + + // Should succeed with default insecure key (with warnings logged) assert!( - engine.is_ok(), - "defaults() should create engine: {:?}", - engine.err() + result.is_ok(), + "Builder should succeed with default key: {:?}", + result.err() ); - let engine = engine.unwrap(); + let engine = result.unwrap(); let connections = engine.list_connections().await; assert!(connections.is_ok(), "Should be able to list connections"); } #[tokio::test(flavor = "multi_thread")] async fn test_from_config_sqlite_filesystem() { - use crate::config::{AppConfig, CatalogConfig, PathsConfig, ServerConfig, StorageConfig}; + use crate::config::{ + AppConfig, CatalogConfig, PathsConfig, SecretsConfig, ServerConfig, StorageConfig, + }; let temp_dir = TempDir::new().unwrap(); let base_dir = temp_dir.path().to_path_buf(); @@ -740,12 +831,16 @@ mod tests { base_dir: Some(base_dir.to_str().unwrap().to_string()), cache_dir: None, }, + secrets: SecretsConfig { + encryption_key: Some(test_secret_key()), + }, }; let engine = RivetEngine::from_config(&config).await; assert!( 
engine.is_ok(), - "from_config should create engine successfully" + "from_config should create engine successfully: {:?}", + engine.err() ); let engine = engine.unwrap(); diff --git a/src/http/app_server.rs b/src/http/app_server.rs index 471953b..3f68cc8 100644 --- a/src/http/app_server.rs +++ b/src/http/app_server.rs @@ -1,7 +1,8 @@ use crate::http::handlers::{ - create_connection_handler, delete_connection_handler, discover_connection_handler, - get_connection_handler, health_handler, list_connections_handler, - purge_connection_cache_handler, purge_table_cache_handler, query_handler, tables_handler, + create_connection_handler, create_secret_handler, delete_connection_handler, + delete_secret_handler, discover_connection_handler, get_connection_handler, get_secret_handler, + health_handler, list_connections_handler, list_secrets_handler, purge_connection_cache_handler, + purge_table_cache_handler, query_handler, tables_handler, update_secret_handler, }; use crate::RivetEngine; use axum::routing::{delete, get, post}; @@ -21,6 +22,8 @@ pub const PATH_CONNECTION: &str = "/connections/{name}"; pub const PATH_CONNECTION_DISCOVER: &str = "/connections/{name}/discover"; pub const PATH_CONNECTION_CACHE: &str = "/connections/{name}/cache"; pub const PATH_TABLE_CACHE: &str = "/connections/{name}/tables/{schema}/{table}/cache"; +pub const PATH_SECRETS: &str = "/secrets"; +pub const PATH_SECRET: &str = "/secrets/{name}"; impl AppServer { pub fn new(engine: RivetEngine) -> Self { @@ -44,6 +47,16 @@ impl AppServer { delete(purge_connection_cache_handler), ) .route(PATH_TABLE_CACHE, delete(purge_table_cache_handler)) + .route( + PATH_SECRETS, + post(create_secret_handler).get(list_secrets_handler), + ) + .route( + PATH_SECRET, + get(get_secret_handler) + .put(update_secret_handler) + .delete(delete_secret_handler), + ) .with_state(engine.clone()), engine, } diff --git a/src/http/error.rs b/src/http/error.rs index 2bbf67c..7d9ca1b 100644 --- a/src/http/error.rs +++ b/src/http/error.rs @@ -53,6 +53,14 @@ impl ApiError { code: "BAD_GATEWAY".to_string(), } } + + pub fn service_unavailable(message: impl Into) -> Self { + Self { + status: StatusCode::SERVICE_UNAVAILABLE, + message: message.into(), + code: "SERVICE_UNAVAILABLE".to_string(), + } + } } impl IntoResponse for ApiError { @@ -76,3 +84,22 @@ impl From for ApiError { ApiError::internal_error(err.to_string()) } } + +/// Convert SecretError to ApiError +impl From for ApiError { + fn from(e: crate::secrets::SecretError) -> Self { + use crate::secrets::SecretError; + let constructor = match &e { + SecretError::NotFound(_) => ApiError::not_found, + SecretError::AlreadyExists(_) | SecretError::CreationInProgress(_) => { + ApiError::conflict + } + SecretError::NotConfigured => ApiError::service_unavailable, + SecretError::InvalidName(_) => ApiError::bad_request, + SecretError::Backend(_) | SecretError::InvalidUtf8 | SecretError::Database(_) => { + ApiError::internal_error + } + }; + constructor(e.to_string()) + } +} diff --git a/src/http/handlers.rs b/src/http/handlers.rs index 103f626..6a60b0f 100644 --- a/src/http/handlers.rs +++ b/src/http/handlers.rs @@ -1,8 +1,9 @@ use crate::http::error::ApiError; use crate::http::models::{ - ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, DiscoverConnectionResponse, - DiscoveryStatus, GetConnectionResponse, ListConnectionsResponse, QueryRequest, QueryResponse, - TableInfo, TablesResponse, + ConnectionInfo, CreateConnectionRequest, CreateConnectionResponse, CreateSecretRequest, + 
CreateSecretResponse, DiscoverConnectionResponse, DiscoveryStatus, GetConnectionResponse, + GetSecretResponse, ListConnectionsResponse, ListSecretsResponse, QueryRequest, QueryResponse, + SecretMetadataResponse, TableInfo, TablesResponse, UpdateSecretRequest, UpdateSecretResponse, }; use crate::http::serialization::{encode_value_at, make_array_encoder}; use crate::source::Source; @@ -369,3 +370,91 @@ pub async fn purge_table_cache_handler( Ok(StatusCode::NO_CONTENT) } + +// Secret management handlers + +/// Handler for POST /secrets +pub async fn create_secret_handler( + State(engine): State>, + Json(request): Json, +) -> Result<(StatusCode, Json), ApiError> { + let secret_manager = engine.secret_manager(); + + secret_manager + .create(&request.name, request.value.as_bytes()) + .await?; + + let metadata = secret_manager.get_metadata(&request.name).await?; + + Ok(( + StatusCode::CREATED, + Json(CreateSecretResponse { + name: metadata.name, + created_at: metadata.created_at, + }), + )) +} + +/// Handler for PUT /secrets/{name} +pub async fn update_secret_handler( + State(engine): State>, + Path(name): Path, + Json(request): Json, +) -> Result, ApiError> { + let secret_manager = engine.secret_manager(); + + secret_manager + .update(&name, request.value.as_bytes()) + .await?; + + let metadata = secret_manager.get_metadata(&name).await?; + + Ok(Json(UpdateSecretResponse { + name: metadata.name, + updated_at: metadata.updated_at, + })) +} + +/// Handler for GET /secrets +pub async fn list_secrets_handler( + State(engine): State>, +) -> Result, ApiError> { + let secret_manager = engine.secret_manager(); + + let secrets = secret_manager.list().await?; + + Ok(Json(ListSecretsResponse { + secrets: secrets + .into_iter() + .map(SecretMetadataResponse::from) + .collect(), + })) +} + +/// Handler for GET /secrets/{name} +pub async fn get_secret_handler( + State(engine): State>, + Path(name): Path, +) -> Result, ApiError> { + let secret_manager = engine.secret_manager(); + + let metadata = secret_manager.get_metadata(&name).await?; + + Ok(Json(GetSecretResponse { + name: metadata.name, + created_at: metadata.created_at, + updated_at: metadata.updated_at, + })) +} + +/// Handler for DELETE /secrets/{name} +pub async fn delete_secret_handler( + State(engine): State>, + Path(name): Path, +) -> Result { + let secret_manager = engine.secret_manager(); + + secret_manager.delete(&name).await?; + + Ok(StatusCode::NO_CONTENT) +} diff --git a/src/http/models.rs b/src/http/models.rs index 2d5a6f7..2c5eebf 100644 --- a/src/http/models.rs +++ b/src/http/models.rs @@ -1,3 +1,5 @@ +use crate::secrets::SecretMetadata; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; /// Request body for POST /query @@ -95,3 +97,64 @@ pub struct GetConnectionResponse { pub table_count: usize, pub synced_table_count: usize, } + +// Secret management models + +/// Request body for POST /secrets +#[derive(Debug, Deserialize)] +pub struct CreateSecretRequest { + pub name: String, + pub value: String, +} + +/// Request body for PUT /secrets/{name} +#[derive(Debug, Deserialize)] +pub struct UpdateSecretRequest { + pub value: String, +} + +/// Response body for POST /secrets +#[derive(Debug, Serialize)] +pub struct CreateSecretResponse { + pub name: String, + pub created_at: DateTime, +} + +/// Response body for PUT /secrets/{name} +#[derive(Debug, Serialize)] +pub struct UpdateSecretResponse { + pub name: String, + pub updated_at: DateTime, +} + +/// Response body for GET /secrets +#[derive(Debug, Serialize)] +pub struct 
ListSecretsResponse { + pub secrets: Vec, +} + +/// Single secret metadata for API responses +#[derive(Debug, Serialize)] +pub struct SecretMetadataResponse { + pub name: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl From for SecretMetadataResponse { + fn from(m: SecretMetadata) -> Self { + Self { + name: m.name, + created_at: m.created_at, + updated_at: m.updated_at, + } + } +} + +/// Response body for GET /secrets/{name} +#[derive(Debug, Serialize)] +pub struct GetSecretResponse { + pub name: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} diff --git a/src/lib.rs b/src/lib.rs index 55e93b5..e7e5674 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod datafetch; pub mod datafusion; mod engine; pub mod http; +pub mod secrets; pub mod source; pub mod storage; diff --git a/src/secrets/backend.rs b/src/secrets/backend.rs new file mode 100644 index 0000000..c4ac5ea --- /dev/null +++ b/src/secrets/backend.rs @@ -0,0 +1,72 @@ +//! Low-level storage abstraction for secrets. +//! +//! The `SecretBackend` trait defines raw storage operations, decoupling +//! the manager layer (validation, encryption, metadata) from persistence. +//! This allows different storage backends (catalog tables, Vault, KMS) +//! without duplicating manager logic. + +use async_trait::async_trait; +use std::fmt::Debug; + +/// Error type for backend storage operations. +#[derive(Debug, thiserror::Error)] +pub enum BackendError { + #[error("Secret '{0}' not found")] + NotFound(String), + + #[error("Storage error: {0}")] + Storage(String), +} + +/// Catalog metadata the manager already knows about a secret. +/// Backends can use this information (e.g. provider_ref) to locate the raw value. +#[derive(Debug, Clone)] +pub struct SecretRecord { + pub name: String, + pub provider_ref: Option, + // Extend with additional fields (version, tags, etc.) as needed. +} + +/// Raw read result returned by a backend. +#[derive(Debug, Clone)] +pub struct BackendRead { + pub value: Vec, + pub provider_ref: Option, +} + +/// Result of writing a secret to a backend. +#[derive(Debug, Clone)] +pub struct BackendWrite { + pub status: WriteStatus, + pub provider_ref: Option, +} + +/// Whether a write created a new secret or updated an existing one. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WriteStatus { + Created, + Updated, +} + +/// Low-level storage trait for secret values. +/// +/// Managers own name validation, encryption/decryption, and catalog metadata. +/// Backends are pure key-value stores that may rely on manager-provided +/// metadata (e.g. `provider_ref`) to locate entries and can return updated +/// metadata for the manager to persist. +#[async_trait] +pub trait SecretBackend: Debug + Send + Sync { + /// Retrieve the raw stored value for `record`, returning `None` if not found. + async fn get(&self, record: &SecretRecord) -> Result, BackendError>; + + /// Store `value` for `record` (create or update). + /// + /// Returns a `BackendWrite` so the manager knows whether it was a create + /// vs update and can persist any backend-provided metadata. + async fn put(&self, record: &SecretRecord, value: &[u8]) -> Result; + + /// Delete the entry identified by `record`. + /// + /// Returns `true` if something was deleted, `false` if it didn't exist. 
+    async fn delete(&self, record: &SecretRecord) -> Result<bool, BackendError>;
+}
diff --git a/src/secrets/encrypted_catalog_backend.rs b/src/secrets/encrypted_catalog_backend.rs
new file mode 100644
index 0000000..79c692d
--- /dev/null
+++ b/src/secrets/encrypted_catalog_backend.rs
@@ -0,0 +1,310 @@
+//! SecretBackend implementation that encrypts values and stores them in catalog tables.
+
+use crate::catalog::CatalogManager;
+use crate::secrets::backend::{
+    BackendError, BackendRead, BackendWrite, SecretBackend, SecretRecord, WriteStatus,
+};
+use crate::secrets::{decrypt, encrypt};
+use async_trait::async_trait;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+
+/// Provider type identifier for this backend.
+pub const PROVIDER_TYPE: &str = "encrypted";
+
+/// SecretBackend that encrypts values with AES-256-GCM-SIV and stores them in the catalog.
+///
+/// This backend handles:
+/// - AES-256-GCM-SIV encryption/decryption
+/// - Storage in the catalog's `encrypted_secret_values` table
+///
+/// The secret name is used as AAD (additional authenticated data) to prevent
+/// ciphertext from being swapped between secrets.
+pub struct EncryptedCatalogBackend {
+    key: [u8; 32],
+    catalog: Arc<dyn CatalogManager>,
+}
+
+impl EncryptedCatalogBackend {
+    /// Creates a new EncryptedCatalogBackend.
+    ///
+    /// # Arguments
+    /// * `key` - 32-byte AES-256 key
+    /// * `catalog` - Catalog manager for database access
+    pub fn new(key: [u8; 32], catalog: Arc<dyn CatalogManager>) -> Self {
+        Self { key, catalog }
+    }
+
+    /// Creates from a base64-encoded key string.
+    pub fn from_base64_key(
+        key_base64: &str,
+        catalog: Arc<dyn CatalogManager>,
+    ) -> Result<Self, BackendError> {
+        use base64::{engine::general_purpose::STANDARD, Engine};
+
+        let key_bytes = STANDARD
+            .decode(key_base64)
+            .map_err(|e| BackendError::Storage(format!("Invalid base64 key: {}", e)))?;
+
+        if key_bytes.len() != 32 {
+            return Err(BackendError::Storage(format!(
+                "Key must be exactly 32 bytes, got {}",
+                key_bytes.len()
+            )));
+        }
+
+        let key: [u8; 32] = key_bytes
+            .try_into()
+            .map_err(|_| BackendError::Storage("Key conversion failed".into()))?;
+
+        Ok(Self::new(key, catalog))
+    }
+
+    /// Returns the provider type for this backend.
+    pub fn provider_type(&self) -> &'static str {
+        PROVIDER_TYPE
+    }
+}
+
+impl Debug for EncryptedCatalogBackend {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("EncryptedCatalogBackend")
+            .field("catalog", &self.catalog)
+            .finish_non_exhaustive()
+    }
+}
+
+#[async_trait]
+impl SecretBackend for EncryptedCatalogBackend {
+    async fn get(&self, record: &SecretRecord) -> Result<Option<BackendRead>, BackendError> {
+        let encrypted = self
+            .catalog
+            .get_encrypted_secret(&record.name)
+            .await
+            .map_err(|e| BackendError::Storage(e.to_string()))?;
+
+        match encrypted {
+            Some(ciphertext) => {
+                let plaintext = decrypt(&self.key, &ciphertext, &record.name)
+                    .map_err(|e| BackendError::Storage(format!("Decryption failed: {}", e)))?;
+
+                Ok(Some(BackendRead {
+                    value: plaintext,
+                    provider_ref: None,
+                }))
+            }
+            None => Ok(None),
+        }
+    }
+
+    async fn put(&self, record: &SecretRecord, value: &[u8]) -> Result<BackendWrite, BackendError> {
+        // Encrypt the value using the secret name as AAD
+        let ciphertext = encrypt(&self.key, value, &record.name)
+            .map_err(|e| BackendError::Storage(format!("Encryption failed: {}", e)))?;
+
+        // Check if secret already exists to determine Created vs Updated
+        let exists = self
+            .catalog
+            .get_encrypted_secret(&record.name)
+            .await
+            .map_err(|e| BackendError::Storage(e.to_string()))?
+ .is_some(); + + // Store encrypted value + self.catalog + .put_encrypted_secret_value(&record.name, &ciphertext) + .await + .map_err(|e| BackendError::Storage(e.to_string()))?; + + Ok(BackendWrite { + status: if exists { + WriteStatus::Updated + } else { + WriteStatus::Created + }, + provider_ref: None, + }) + } + + async fn delete(&self, record: &SecretRecord) -> Result { + self.catalog + .delete_encrypted_secret_value(&record.name) + .await + .map_err(|e| BackendError::Storage(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::catalog::SqliteCatalogManager; + use tempfile::TempDir; + + fn test_key() -> [u8; 32] { + [0x42; 32] + } + + async fn test_backend() -> (EncryptedCatalogBackend, TempDir) { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + catalog.run_migrations().await.unwrap(); + (EncryptedCatalogBackend::new(test_key(), catalog), dir) + } + + fn record(name: &str) -> SecretRecord { + SecretRecord { + name: name.to_string(), + provider_ref: None, + } + } + + #[tokio::test] + async fn test_put_and_get() { + let (backend, _dir) = test_backend().await; + let rec = record("test-secret"); + let value = b"my-secret-password"; + + let write = backend.put(&rec, value).await.unwrap(); + assert_eq!(write.status, WriteStatus::Created); + + let read = backend.get(&rec).await.unwrap().unwrap(); + assert_eq!(read.value, value); + } + + #[tokio::test] + async fn test_encryption_is_applied() { + let (backend, _dir) = test_backend().await; + let rec = record("encrypted-test"); + let plaintext = b"sensitive-data"; + + backend.put(&rec, plaintext).await.unwrap(); + + // Read raw from catalog - should be encrypted, not plaintext + let raw = backend + .catalog + .get_encrypted_secret(&rec.name) + .await + .unwrap() + .unwrap(); + + assert_ne!(raw, plaintext, "Value should be encrypted in storage"); + assert!(raw.len() > plaintext.len(), "Ciphertext should be larger"); + } + + #[tokio::test] + async fn test_get_not_found() { + let (backend, _dir) = test_backend().await; + let rec = record("nonexistent"); + + let result = backend.get(&rec).await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_put_update() { + let (backend, _dir) = test_backend().await; + let rec = record("updatable"); + + let write1 = backend.put(&rec, b"old-value").await.unwrap(); + assert_eq!(write1.status, WriteStatus::Created); + + let write2 = backend.put(&rec, b"new-value").await.unwrap(); + assert_eq!(write2.status, WriteStatus::Updated); + + let read = backend.get(&rec).await.unwrap().unwrap(); + assert_eq!(read.value, b"new-value"); + } + + #[tokio::test] + async fn test_delete() { + let (backend, _dir) = test_backend().await; + let rec = record("to-delete"); + + backend.put(&rec, b"value").await.unwrap(); + + let deleted = backend.delete(&rec).await.unwrap(); + assert!(deleted); + + let result = backend.get(&rec).await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_delete_not_found() { + let (backend, _dir) = test_backend().await; + let rec = record("nonexistent"); + + let deleted = backend.delete(&rec).await.unwrap(); + assert!(!deleted); + } + + #[tokio::test] + async fn test_wrong_key_fails_decrypt() { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + 
catalog.run_migrations().await.unwrap(); + + // Store with one key + let backend1 = EncryptedCatalogBackend::new([0x42; 32], catalog.clone()); + let rec = record("key-test"); + backend1.put(&rec, b"secret").await.unwrap(); + + // Try to read with different key + let backend2 = EncryptedCatalogBackend::new([0x43; 32], catalog); + let result = backend2.get(&rec).await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Decryption failed")); + } + + #[tokio::test] + async fn test_from_base64_key() { + use base64::{engine::general_purpose::STANDARD, Engine}; + + let key = [0x42u8; 32]; + let key_base64 = STANDARD.encode(key); + + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + + let result = EncryptedCatalogBackend::from_base64_key(&key_base64, catalog); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_from_base64_key_invalid() { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + + // Too short + let result = EncryptedCatalogBackend::from_base64_key("dG9vLXNob3J0", catalog.clone()); + assert!(result.is_err()); + + // Invalid base64 + let result = EncryptedCatalogBackend::from_base64_key("not-valid-base64!!!", catalog); + assert!(result.is_err()); + } +} diff --git a/src/secrets/encryption.rs b/src/secrets/encryption.rs new file mode 100644 index 0000000..725f4d5 --- /dev/null +++ b/src/secrets/encryption.rs @@ -0,0 +1,236 @@ +//! Low-level encryption/decryption for secret values. +//! +//! Format: ['R','V','S','1'][scheme][key_version][12-byte nonce][ciphertext...] +//! - Magic: 4 bytes "RVS1" +//! - Scheme: 1 byte (0x01 = AES-256-GCM-SIV) +//! - Key version: 1 byte (0x01 = first key) +//! - Nonce: 12 bytes +//! - Ciphertext: remaining bytes + +use aes_gcm_siv::{ + aead::{Aead, KeyInit, OsRng}, + Aes256GcmSiv, Nonce, +}; + +const MAGIC: &[u8; 4] = b"RVS1"; +const SCHEME_AES_256_GCM_SIV: u8 = 0x01; +const KEY_VERSION_1: u8 = 0x01; +const NONCE_SIZE: usize = 12; +const HEADER_SIZE: usize = 4 + 1 + 1 + NONCE_SIZE; // 18 bytes + +/// Encrypts plaintext using AES-256-GCM-SIV. +/// +/// # Arguments +/// * `key` - 32-byte encryption key +/// * `plaintext` - Data to encrypt +/// * `aad` - Associated authenticated data (secret name, prevents blob swapping) +pub fn encrypt(key: &[u8; 32], plaintext: &[u8], aad: &str) -> Result, EncryptError> { + let cipher = + Aes256GcmSiv::new_from_slice(key).map_err(|e| EncryptError::CipherInit(e.to_string()))?; + + // Generate random nonce + let mut nonce_bytes = [0u8; NONCE_SIZE]; + aes_gcm_siv::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut nonce_bytes); + + let nonce = Nonce::from(nonce_bytes); + + let ciphertext = cipher + .encrypt( + &nonce, + aes_gcm_siv::aead::Payload { + msg: plaintext, + aad: aad.as_bytes(), + }, + ) + .map_err(|e| EncryptError::Encryption(e.to_string()))?; + + // Build output: magic + scheme + key_version + nonce + ciphertext + let mut output = Vec::with_capacity(HEADER_SIZE + ciphertext.len()); + output.extend_from_slice(MAGIC); + output.push(SCHEME_AES_256_GCM_SIV); + output.push(KEY_VERSION_1); + output.extend_from_slice(&nonce_bytes); + output.extend_from_slice(&ciphertext); + + Ok(output) +} + +/// Decrypts ciphertext encrypted with `encrypt()`. 
+/// +/// # Arguments +/// * `key` - 32-byte encryption key (must match key version in blob) +/// * `encrypted` - Encrypted blob from `encrypt()` +/// * `aad` - Associated authenticated data (must match value used during encryption) +pub fn decrypt(key: &[u8; 32], encrypted: &[u8], aad: &str) -> Result, DecryptError> { + if encrypted.len() < HEADER_SIZE { + return Err(DecryptError::TooShort); + } + + // Verify magic + if &encrypted[0..4] != MAGIC { + return Err(DecryptError::InvalidMagic); + } + + // Check scheme + let scheme = encrypted[4]; + if scheme != SCHEME_AES_256_GCM_SIV { + return Err(DecryptError::UnsupportedScheme(scheme)); + } + + // Check key version (for future key rotation support) + let key_version = encrypted[5]; + if key_version != KEY_VERSION_1 { + return Err(DecryptError::UnsupportedKeyVersion(key_version)); + } + + // Extract nonce and ciphertext + let nonce_bytes: [u8; NONCE_SIZE] = encrypted[6..6 + NONCE_SIZE] + .try_into() + .expect("slice length matches NONCE_SIZE"); + let ciphertext = &encrypted[HEADER_SIZE..]; + + let cipher = + Aes256GcmSiv::new_from_slice(key).map_err(|e| DecryptError::CipherInit(e.to_string()))?; + + let nonce = Nonce::from(nonce_bytes); + + let plaintext = cipher + .decrypt( + &nonce, + aes_gcm_siv::aead::Payload { + msg: ciphertext, + aad: aad.as_bytes(), + }, + ) + .map_err(|_| DecryptError::AuthenticationFailed)?; + + Ok(plaintext) +} + +#[derive(Debug, thiserror::Error)] +pub enum EncryptError { + #[error("Failed to initialize cipher: {0}")] + CipherInit(String), + #[error("Encryption failed: {0}")] + Encryption(String), +} + +#[derive(Debug, thiserror::Error)] +pub enum DecryptError { + #[error("Encrypted data too short")] + TooShort, + #[error("Invalid magic bytes (not a RivetDB secret)")] + InvalidMagic, + #[error("Unsupported encryption scheme: {0}")] + UnsupportedScheme(u8), + #[error("Unsupported key version: {0}")] + UnsupportedKeyVersion(u8), + #[error("Failed to initialize cipher: {0}")] + CipherInit(String), + #[error("Authentication failed (wrong key, corrupted data, or AAD mismatch)")] + AuthenticationFailed, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_key() -> [u8; 32] { + [0x42; 32] // Deterministic key for tests + } + + #[test] + fn test_encrypt_decrypt_roundtrip() { + let key = test_key(); + let plaintext = b"my-secret-password"; + let aad = "my-secret-name"; + + let encrypted = encrypt(&key, plaintext, aad).unwrap(); + let decrypted = decrypt(&key, &encrypted, aad).unwrap(); + + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_encrypted_format() { + let key = test_key(); + let encrypted = encrypt(&key, b"test", "name").unwrap(); + + assert!(encrypted.len() >= HEADER_SIZE); + assert_eq!(&encrypted[0..4], MAGIC); + assert_eq!(encrypted[4], SCHEME_AES_256_GCM_SIV); + assert_eq!(encrypted[5], KEY_VERSION_1); + } + + #[test] + fn test_wrong_key_fails() { + let key1 = [0x42; 32]; + let key2 = [0x43; 32]; + let plaintext = b"secret"; + let aad = "name"; + + let encrypted = encrypt(&key1, plaintext, aad).unwrap(); + let result = decrypt(&key2, &encrypted, aad); + + assert!(matches!(result, Err(DecryptError::AuthenticationFailed))); + } + + #[test] + fn test_wrong_aad_fails() { + let key = test_key(); + let plaintext = b"secret"; + + let encrypted = encrypt(&key, plaintext, "correct-name").unwrap(); + let result = decrypt(&key, &encrypted, "wrong-name"); + + assert!(matches!(result, Err(DecryptError::AuthenticationFailed))); + } + + #[test] + fn test_corrupted_ciphertext_fails() { + let key = test_key(); 
+ let mut encrypted = encrypt(&key, b"secret", "name").unwrap(); + + // Corrupt the ciphertext + if let Some(byte) = encrypted.last_mut() { + *byte ^= 0xFF; + } + + let result = decrypt(&key, &encrypted, "name"); + assert!(matches!(result, Err(DecryptError::AuthenticationFailed))); + } + + #[test] + fn test_invalid_magic_fails() { + let key = test_key(); + let mut encrypted = encrypt(&key, b"secret", "name").unwrap(); + encrypted[0] = 0x00; + + let result = decrypt(&key, &encrypted, "name"); + assert!(matches!(result, Err(DecryptError::InvalidMagic))); + } + + #[test] + fn test_too_short_fails() { + let key = test_key(); + let result = decrypt(&key, &[0; 10], "name"); + assert!(matches!(result, Err(DecryptError::TooShort))); + } + + #[test] + fn test_empty_plaintext() { + let key = test_key(); + let encrypted = encrypt(&key, b"", "name").unwrap(); + let decrypted = decrypt(&key, &encrypted, "name").unwrap(); + assert!(decrypted.is_empty()); + } + + #[test] + fn test_large_plaintext() { + let key = test_key(); + let plaintext = vec![0xAB; 10000]; + let encrypted = encrypt(&key, &plaintext, "name").unwrap(); + let decrypted = decrypt(&key, &encrypted, "name").unwrap(); + assert_eq!(decrypted, plaintext); + } +} diff --git a/src/secrets/mod.rs b/src/secrets/mod.rs new file mode 100644 index 0000000..464eb7a --- /dev/null +++ b/src/secrets/mod.rs @@ -0,0 +1,644 @@ +mod backend; +mod encrypted_catalog_backend; +mod encryption; +mod validation; + +pub use backend::{BackendError, BackendRead, BackendWrite, SecretBackend, SecretRecord}; +pub use encrypted_catalog_backend::{ + EncryptedCatalogBackend, PROVIDER_TYPE as ENCRYPTED_PROVIDER_TYPE, +}; +pub use encryption::{decrypt, encrypt, DecryptError, EncryptError}; +pub use validation::validate_and_normalize_name; + +use crate::catalog::CatalogManager; +use chrono::Utc; +use std::fmt::Debug; +use std::sync::Arc; + +use chrono::{DateTime, Utc as ChronoUtc}; + +/// Status of a secret in its lifecycle. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SecretStatus { + /// Secret is being created; backend write in progress. + Creating, + /// Secret is active and available. + Active, + /// Secret is marked for deletion; cleanup in progress. + PendingDelete, +} + +impl SecretStatus { + pub fn as_str(&self) -> &'static str { + match self { + SecretStatus::Creating => "creating", + SecretStatus::Active => "active", + SecretStatus::PendingDelete => "pending_delete", + } + } +} + +impl std::str::FromStr for SecretStatus { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "creating" => Ok(SecretStatus::Creating), + "active" => Ok(SecretStatus::Active), + "pending_delete" => Ok(SecretStatus::PendingDelete), + _ => Err(()), + } + } +} + +/// Metadata about a stored secret (no sensitive data). +#[derive(Debug, Clone)] +pub struct SecretMetadata { + pub name: String, + pub provider: String, + pub provider_ref: Option, + pub status: SecretStatus, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// Errors from secret manager operations. 
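As a worked example of the blob format `encryption.rs` defines above: assuming the `aes-gcm-siv` 0.11 AEAD appends its 16-byte authentication tag to the ciphertext (the usual `aead` crate behavior), a 4-byte plaintext yields an 18-byte header plus 20 bytes of ciphertext. A hypothetical test:

```
// Hypothetical layout check for the RVS1 blob format described above.
#[test]
fn blob_layout_example() {
    let key = [0u8; 32];
    let blob = encrypt(&key, b"abcd", "example-name").unwrap();
    assert_eq!(&blob[0..4], b"RVS1"); // magic
    assert_eq!(blob[4], 0x01); // scheme: AES-256-GCM-SIV
    assert_eq!(blob[5], 0x01); // key version 1
    let _nonce = &blob[6..18]; // 12 random bytes
    assert_eq!(blob.len(), 18 + 4 + 16); // header + plaintext + 16-byte tag
}
```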
+#[derive(Debug, thiserror::Error)] +pub enum SecretError { + #[error("Secret '{0}' not found")] + NotFound(String), + + #[error("Secret '{0}' already exists")] + AlreadyExists(String), + + #[error( + "Secret '{0}' is being created by another process; delete it first if you want to retry" + )] + CreationInProgress(String), + + #[error("Secret manager not configured")] + NotConfigured, + + #[error( + "Invalid secret name '{0:.128}': must be 1-128 characters, alphanumeric with _ and - only" + )] + InvalidName(String), + + #[error("Backend error: {0}")] + Backend(String), + + #[error("Invalid secret value: not valid UTF-8")] + InvalidUtf8, + + #[error("Database error: {0}")] + Database(String), +} + +impl From for SecretError { + fn from(e: BackendError) -> Self { + match e { + BackendError::NotFound(name) => SecretError::NotFound(name), + BackendError::Storage(msg) => SecretError::Backend(msg), + } + } +} + +/// Coordinates secret storage between a backend and catalog metadata. +/// +/// The `SecretManager` is the main entry point for secret operations. It: +/// - Validates and normalizes secret names +/// - Manages metadata (timestamps, provider info) in the catalog +/// - Delegates value storage to a `SecretBackend` +/// +/// Different backends handle storage differently (local encrypted, Vault, KMS, etc). +#[derive(Debug)] +pub struct SecretManager { + backend: Arc, + catalog: Arc, + provider_type: String, +} + +impl SecretManager { + /// Creates a new SecretManager. + /// + /// # Arguments + /// * `backend` - Backend for storing secret values + /// * `catalog` - Catalog manager for metadata persistence + /// * `provider_type` - Provider identifier stored in metadata (e.g., "encrypted", "vault") + pub fn new( + backend: Arc, + catalog: Arc, + provider_type: impl Into, + ) -> Self { + Self { + backend, + catalog, + provider_type: provider_type.into(), + } + } + + /// Get a secret's raw bytes by name. + pub async fn get(&self, name: &str) -> Result, SecretError> { + let normalized = validate_and_normalize_name(name)?; + + // Fetch metadata to get provider_ref for the backend + let metadata = self + .catalog + .get_secret_metadata(&normalized) + .await + .map_err(|e| SecretError::Database(e.to_string()))? + .ok_or_else(|| SecretError::NotFound(normalized.clone()))?; + + let record = SecretRecord { + name: normalized.clone(), + provider_ref: metadata.provider_ref, + }; + + let read = self + .backend + .get(&record) + .await? + .ok_or(SecretError::NotFound(normalized))?; + + Ok(read.value) + } + + /// Get a secret as a UTF-8 string. + pub async fn get_string(&self, name: &str) -> Result { + let bytes = self.get(name).await?; + String::from_utf8(bytes).map_err(|_| SecretError::InvalidUtf8) + } + + /// Get a secret's metadata by name (no value). + pub async fn get_metadata(&self, name: &str) -> Result { + let normalized = validate_and_normalize_name(name)?; + + self.catalog + .get_secret_metadata(&normalized) + .await + .map_err(|e| SecretError::Database(e.to_string()))? + .ok_or(SecretError::NotFound(normalized)) + } + + /// Create a new secret. + /// + /// Uses optimistic locking to handle concurrent creation attempts safely: + /// 1. Insert metadata with status=Creating (claims the name) + /// 2. Store value in backend + /// 3. Update metadata to Active with optimistic lock on created_at + /// + /// If step 2 fails, the Creating record remains (user must delete to retry). + /// If another request races us at step 1, we detect it and return appropriate error. 
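A hypothetical sketch of the race this create protocol guards against, assuming a shared `manager` built as in the tests below — under the intended design, exactly one of two concurrent creates for the same name succeeds:

```
// Illustration only; relies on the catalog's unique name constraint so that
// exactly one metadata insert can claim the row.
let (a, b) = tokio::join!(
    manager.create("db-pass", b"hunter2"),
    manager.create("db-pass", b"hunter2"),
);
// The loser observes AlreadyExists (winner already reached Active) or
// CreationInProgress (winner still between steps 1 and 3).
assert!(a.is_ok() ^ b.is_ok());
```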
+ pub async fn create(&self, name: &str, value: &[u8]) -> Result<(), SecretError> { + use crate::catalog::OptimisticLock; + + let normalized = validate_and_normalize_name(name)?; + + // Check existing state + let existing = self + .catalog + .get_secret_metadata_any_status(&normalized) + .await + .map_err(|e| SecretError::Database(e.to_string()))?; + + if let Some(metadata) = existing { + match metadata.status { + SecretStatus::Active => return Err(SecretError::AlreadyExists(normalized)), + SecretStatus::Creating => return Err(SecretError::CreationInProgress(normalized)), + SecretStatus::PendingDelete => { + // Stale record from failed delete - clean it up + tracing::info!( + secret = %normalized, + "Cleaning up stale pending_delete metadata before create" + ); + let _ = self.catalog.delete_secret_metadata(&normalized).await; + } + } + } + + // Step 1: Insert metadata with status=Creating to claim the name + let now = Utc::now(); + let creating_metadata = SecretMetadata { + name: normalized.clone(), + provider: self.provider_type.clone(), + provider_ref: None, + status: SecretStatus::Creating, + created_at: now, + updated_at: now, + }; + + if let Err(e) = self + .catalog + .create_secret_metadata(&creating_metadata) + .await + { + // Another request likely beat us - check what's there now + let current = self + .catalog + .get_secret_metadata_any_status(&normalized) + .await + .map_err(|e2| SecretError::Database(e2.to_string()))?; + + return match current.map(|m| m.status) { + Some(SecretStatus::Active) => Err(SecretError::AlreadyExists(normalized)), + Some(SecretStatus::Creating) => Err(SecretError::CreationInProgress(normalized)), + _ => Err(SecretError::Database(e.to_string())), + }; + } + + // Step 2: Store value in backend + let record = SecretRecord { + name: normalized.clone(), + provider_ref: None, + }; + + let write = match self.backend.put(&record, value).await { + Ok(w) => w, + Err(e) => { + // Backend failed - leave Creating record in place. + // User must delete to retry. + tracing::error!( + secret = %normalized, + error = %e, + "Backend write failed during create; secret left in Creating state" + ); + return Err(e.into()); + } + }; + + // Step 3: Update metadata to Active with optimistic lock + let active_metadata = SecretMetadata { + name: normalized.clone(), + provider: self.provider_type.clone(), + provider_ref: write.provider_ref, + status: SecretStatus::Active, + created_at: now, + updated_at: Utc::now(), + }; + + let lock = OptimisticLock::from(now); + let updated = self + .catalog + .update_secret_metadata(&active_metadata, Some(lock)) + .await + .map_err(|e| SecretError::Database(e.to_string()))?; + + if !updated { + // Our Creating record was deleted while we were writing to backend. + // The backend value is now orphaned (harmless, will be overwritten on next create). + return Err(SecretError::Database(format!( + "Secret '{}' was deleted by another process while being created; please retry", + normalized + ))); + } + + Ok(()) + } + + /// Update an existing secret's value. + /// + /// Fails with `NotFound` if the secret doesn't exist. + /// For creating a new secret, use `create` instead. + pub async fn update(&self, name: &str, value: &[u8]) -> Result<(), SecretError> { + let normalized = validate_and_normalize_name(name)?; + + // Load existing metadata to get provider_ref + let metadata = self + .catalog + .get_secret_metadata(&normalized) + .await + .map_err(|e| SecretError::Database(e.to_string()))? 
+ .ok_or_else(|| SecretError::NotFound(normalized.clone()))?; + + let existing_ref = metadata.provider_ref; + + let record = SecretRecord { + name: normalized.clone(), + provider_ref: existing_ref.clone(), + }; + + // Store via backend (update path) + let write = self.backend.put(&record, value).await?; + + // Use new provider_ref if backend returned one, otherwise preserve existing + let updated_ref = write.provider_ref.or(existing_ref); + + // Update metadata timestamp (no optimistic lock for updates - last write wins) + let now = Utc::now(); + let updated_metadata = SecretMetadata { + name: normalized.clone(), + provider: self.provider_type.clone(), + provider_ref: updated_ref, + status: SecretStatus::Active, + created_at: metadata.created_at, + updated_at: now, + }; + + self.catalog + .update_secret_metadata(&updated_metadata, None) + .await + .map_err(|e| SecretError::Database(e.to_string()))?; + + Ok(()) + } + + /// Delete a secret using three-phase commit. + /// + /// 1. Mark metadata as 'pending_delete' (secret becomes invisible) + /// 2. Delete from backend + /// 3. Delete metadata row + /// + /// If step 2 fails, the secret remains in 'pending_delete' state. + /// Subsequent delete calls can retry (we use any_status lookup). + /// If step 3 fails, the secret is effectively deleted (value gone). + pub async fn delete(&self, name: &str) -> Result<(), SecretError> { + let normalized = validate_and_normalize_name(name)?; + + // Fetch metadata to get provider_ref for the backend. + // Use any_status so we can retry deletion of pending_delete secrets. + let metadata = self + .catalog + .get_secret_metadata_any_status(&normalized) + .await + .map_err(|e| SecretError::Database(e.to_string()))? + .ok_or_else(|| SecretError::NotFound(normalized.clone()))?; + + // Phase 1: Mark as pending_delete (secret becomes invisible to reads) + // This is idempotent if already pending_delete. + self.catalog + .set_secret_status(&normalized, SecretStatus::PendingDelete) + .await + .map_err(|e| SecretError::Database(e.to_string()))?; + + // Phase 2: Delete from backend + let record = SecretRecord { + name: normalized.clone(), + provider_ref: metadata.provider_ref, + }; + + if let Err(e) = self.backend.delete(&record).await { + tracing::error!( + secret = %normalized, + error = %e, + "Failed to delete secret from backend; secret left in pending_delete state" + ); + return Err(e.into()); + } + + // Phase 3: Delete metadata row + if let Err(e) = self.catalog.delete_secret_metadata(&normalized).await { + // Backend value is gone but metadata remains in pending_delete + // This is acceptable - create will clean it up + tracing::warn!( + secret = %normalized, + error = %e, + "Secret value deleted but failed to remove metadata; will be cleaned up on next create" + ); + } + + Ok(()) + } + + /// List all secrets (metadata only, no values). 
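Taken together, the manager API reads like this sketch (an async context returning `anyhow::Result` is assumed; the secret name is illustrative):

```
// End-to-end usage sketch of SecretManager.
manager.create("prod-pg", b"s3cr3t").await?;
let pw = manager.get_string("prod-pg").await?; // "s3cr3t"
manager.update("prod-pg", b"rotated").await?;
for meta in manager.list().await? {
    // Metadata only: list() never returns secret values.
    println!("{} [{}]", meta.name, meta.status.as_str());
}
manager.delete("prod-pg").await?;
```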
+ pub async fn list(&self) -> Result, SecretError> { + self.catalog + .list_secrets() + .await + .map_err(|e| SecretError::Database(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::catalog::SqliteCatalogManager; + use tempfile::TempDir; + + fn test_key() -> [u8; 32] { + [0x42; 32] + } + + async fn test_manager() -> (SecretManager, TempDir) { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + catalog.run_migrations().await.unwrap(); + + let backend = Arc::new(EncryptedCatalogBackend::new(test_key(), catalog.clone())); + + ( + SecretManager::new(backend, catalog, ENCRYPTED_PROVIDER_TYPE), + dir, + ) + } + + #[tokio::test] + async fn test_create_and_get() { + let (manager, _dir) = test_manager().await; + let value = b"my-secret-password"; + + manager.create("my-secret", value).await.unwrap(); + let retrieved = manager.get("my-secret").await.unwrap(); + + assert_eq!(retrieved, value); + } + + #[tokio::test] + async fn test_create_and_get_string() { + let (manager, _dir) = test_manager().await; + let value = "my-string-secret"; + + manager + .create("string-secret", value.as_bytes()) + .await + .unwrap(); + let retrieved = manager.get_string("string-secret").await.unwrap(); + + assert_eq!(retrieved, value); + } + + #[tokio::test] + async fn test_create_already_exists() { + let (manager, _dir) = test_manager().await; + + manager.create("duplicate", b"first").await.unwrap(); + let result = manager.create("duplicate", b"second").await; + + assert!(matches!(result, Err(SecretError::AlreadyExists(_)))); + } + + #[tokio::test] + async fn test_get_not_found() { + let (manager, _dir) = test_manager().await; + let result = manager.get("nonexistent").await; + + assert!(matches!(result, Err(SecretError::NotFound(_)))); + } + + #[tokio::test] + async fn test_update_existing() { + let (manager, _dir) = test_manager().await; + + manager.create("updatable", b"old-value").await.unwrap(); + manager.update("updatable", b"new-value").await.unwrap(); + + let retrieved = manager.get("updatable").await.unwrap(); + assert_eq!(retrieved, b"new-value"); + } + + #[tokio::test] + async fn test_update_not_found() { + let (manager, _dir) = test_manager().await; + let result = manager.update("nonexistent", b"value").await; + + assert!(matches!(result, Err(SecretError::NotFound(_)))); + } + + #[tokio::test] + async fn test_delete() { + let (manager, _dir) = test_manager().await; + + manager.create("to-delete", b"value").await.unwrap(); + manager.delete("to-delete").await.unwrap(); + + let result = manager.get("to-delete").await; + assert!(matches!(result, Err(SecretError::NotFound(_)))); + } + + #[tokio::test] + async fn test_delete_not_found() { + let (manager, _dir) = test_manager().await; + let result = manager.delete("nonexistent").await; + + assert!(matches!(result, Err(SecretError::NotFound(_)))); + } + + #[tokio::test] + async fn test_list_and_metadata() { + let (manager, _dir) = test_manager().await; + + manager.create("secret-a", b"value-a").await.unwrap(); + manager.create("secret-b", b"value-b").await.unwrap(); + + let list = manager.list().await.unwrap(); + assert_eq!(list.len(), 2); + + let metadata = manager.get_metadata("secret-a").await.unwrap(); + assert_eq!(metadata.name, "secret-a"); + } + + #[tokio::test] + async fn test_name_normalization() { + let (manager, _dir) = test_manager().await; + + manager.create("My-Secret", 
b"value").await.unwrap(); + + // Should be able to retrieve with different case + let retrieved = manager.get("my-secret").await.unwrap(); + assert_eq!(retrieved, b"value"); + } + + #[tokio::test] + async fn test_invalid_name_on_create() { + let (manager, _dir) = test_manager().await; + + let result = manager.create("invalid name!", b"value").await; + assert!(matches!(result, Err(SecretError::InvalidName(_)))); + } + + #[tokio::test] + async fn test_invalid_name_on_update() { + let (manager, _dir) = test_manager().await; + + let result = manager.update("invalid name!", b"value").await; + assert!(matches!(result, Err(SecretError::InvalidName(_)))); + } + + #[tokio::test] + async fn test_delete_retry_after_pending_delete() { + // Simulates the scenario where backend delete fails, leaving secret in pending_delete. + // Subsequent delete calls should still work (not return NotFound). + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + catalog.run_migrations().await.unwrap(); + + let backend = Arc::new(EncryptedCatalogBackend::new(test_key(), catalog.clone())); + let manager = SecretManager::new(backend, catalog.clone(), ENCRYPTED_PROVIDER_TYPE); + + // Create a secret + manager.create("retry-delete", b"value").await.unwrap(); + + // Simulate a failed delete by setting status to pending_delete directly + catalog + .set_secret_status("retry-delete", SecretStatus::PendingDelete) + .await + .unwrap(); + + // Secret should no longer be visible to normal reads + let result = manager.get("retry-delete").await; + assert!(matches!(result, Err(SecretError::NotFound(_)))); + + // But delete should still work (this was the bug - it would return NotFound) + manager.delete("retry-delete").await.unwrap(); + + // Verify it's fully deleted + let any_status = catalog + .get_secret_metadata_any_status("retry-delete") + .await + .unwrap(); + assert!(any_status.is_none()); + } + + #[tokio::test] + async fn test_create_blocked_by_creating_status() { + // If a secret is stuck in "creating" status (e.g., backend write failed), + // subsequent creates should return CreationInProgress error. 
+ let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("test.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + catalog.run_migrations().await.unwrap(); + + let backend = Arc::new(EncryptedCatalogBackend::new(test_key(), catalog.clone())); + let manager = SecretManager::new(backend, catalog.clone(), ENCRYPTED_PROVIDER_TYPE); + + // Manually insert a secret in "creating" status (simulating failed backend write) + let now = chrono::Utc::now(); + let creating_metadata = SecretMetadata { + name: "stuck-secret".to_string(), + provider: ENCRYPTED_PROVIDER_TYPE.to_string(), + provider_ref: None, + status: SecretStatus::Creating, + created_at: now, + updated_at: now, + }; + catalog + .create_secret_metadata(&creating_metadata) + .await + .unwrap(); + + // Attempt to create the same secret should return CreationInProgress + let result = manager.create("stuck-secret", b"value").await; + assert!(matches!(result, Err(SecretError::CreationInProgress(_)))); + + // User can delete the stuck secret to retry + manager.delete("stuck-secret").await.unwrap(); + + // Now create should succeed + manager.create("stuck-secret", b"value").await.unwrap(); + let retrieved = manager.get("stuck-secret").await.unwrap(); + assert_eq!(retrieved, b"value"); + } +} diff --git a/src/secrets/validation.rs b/src/secrets/validation.rs new file mode 100644 index 0000000..c6fade4 --- /dev/null +++ b/src/secrets/validation.rs @@ -0,0 +1,56 @@ +use super::SecretError; + +/// Validates a secret name and returns the normalized (lowercase) form. +/// Returns an error if the name doesn't match the allowed pattern. +/// +/// Valid names: 1-128 characters, alphanumeric with _ and - only. +pub fn validate_and_normalize_name(name: &str) -> Result { + if name.is_empty() || name.len() > 128 { + return Err(SecretError::InvalidName(name.to_string())); + } + if !name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + return Err(SecretError::InvalidName(name.to_string())); + } + Ok(name.to_lowercase()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_names() { + assert_eq!( + validate_and_normalize_name("my-secret").unwrap(), + "my-secret" + ); + assert_eq!( + validate_and_normalize_name("MY_SECRET").unwrap(), + "my_secret" + ); + assert_eq!( + validate_and_normalize_name("secret123").unwrap(), + "secret123" + ); + assert_eq!(validate_and_normalize_name("a").unwrap(), "a"); + assert_eq!(validate_and_normalize_name("A-B_c-1").unwrap(), "a-b_c-1"); + } + + #[test] + fn test_invalid_names() { + assert!(validate_and_normalize_name("").is_err()); + assert!(validate_and_normalize_name("has space").is_err()); + assert!(validate_and_normalize_name("has.dot").is_err()); + assert!(validate_and_normalize_name("has/slash").is_err()); + assert!(validate_and_normalize_name(&"a".repeat(129)).is_err()); + } + + #[test] + fn test_max_length_name() { + let max_name = "a".repeat(128); + assert!(validate_and_normalize_name(&max_name).is_ok()); + } +} diff --git a/src/source.rs b/src/source.rs index fa4152a..e477740 100644 --- a/src/source.rs +++ b/src/source.rs @@ -1,8 +1,40 @@ +use crate::secrets::SecretManager; use serde::{Deserialize, Serialize}; -use urlencoding::encode; + +/// Credential storage - either no credential or a reference to a stored secret. 
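For reference, the serde tag attributes on the enum defined just below produce the same `secret_ref` wire shape the design doc shows for connection configs. A hypothetical round-trip:

```
// Wire-shape sketch for the tagged Credential enum.
let c: Credential =
    serde_json::from_str(r#"{"type":"secret_ref","name":"prod-pg"}"#).unwrap();
assert_eq!(c, Credential::SecretRef { name: "prod-pg".into() });
assert_eq!(
    serde_json::to_string(&Credential::None).unwrap(),
    r#"{"type":"none"}"#
);
```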
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Credential { + #[default] + None, + SecretRef { + name: String, + }, +} + +impl Credential { + /// Resolve the credential to a plaintext string. + /// Returns an error if the credential is None or the secret cannot be found/decoded. + pub async fn resolve(&self, secrets: &SecretManager) -> anyhow::Result { + match self { + Credential::None => Err(anyhow::anyhow!("no credential configured")), + Credential::SecretRef { name } => { + let bytes = secrets + .get(name) + .await + .map_err(|e| anyhow::anyhow!("failed to resolve secret '{}': {}", name, e))?; + String::from_utf8(bytes) + .map_err(|_| anyhow::anyhow!("secret '{}' is not valid UTF-8", name)) + } + } + } +} /// Represents a data source connection with typed configuration. /// The `type` field is used as the JSON discriminator via serde's tag attribute. +/// +/// Credentials are stored as secrets and referenced via the `credential` field. +/// Use `credential().resolve(secrets)` to obtain the plaintext value. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Source { @@ -10,21 +42,24 @@ pub enum Source { host: String, port: u16, user: String, - password: String, database: String, + #[serde(default)] + credential: Credential, }, Snowflake { account: String, user: String, - password: String, warehouse: String, database: String, #[serde(skip_serializing_if = "Option::is_none")] role: Option, + #[serde(default)] + credential: Credential, }, Motherduck { - token: String, database: String, + #[serde(default)] + credential: Credential, }, Duckdb { path: String, @@ -51,47 +86,13 @@ impl Source { } } - /// Builds the connection string for this source. - /// User-provided values are URL-encoded to prevent connection string injection. - pub fn connection_string(&self) -> String { + /// Access the credential field. + pub fn credential(&self) -> &Credential { match self { - Source::Postgres { - host, - port, - user, - password, - database, - } => { - format!( - "postgresql://{}:{}@{}:{}/{}", - encode(user), - encode(password), - encode(host), - port, - encode(database) - ) - } - Source::Snowflake { - account, - user, - password, - warehouse, - database, - .. - } => { - format!( - "{}:{}@{}/{}/{}", - encode(user), - encode(password), - encode(account), - encode(database), - encode(warehouse) - ) - } - Source::Motherduck { token, database } => { - format!("md:{}?motherduck_token={}", encode(database), encode(token)) - } - Source::Duckdb { path } => path.clone(), + Source::Postgres { credential, .. } => credential, + Source::Snowflake { credential, .. } => credential, + Source::Motherduck { credential, .. } => credential, + Source::Duckdb { .. 
} => &Credential::None, } } } @@ -106,13 +107,16 @@ mod tests { host: "localhost".to_string(), port: 5432, user: "postgres".to_string(), - password: "secret".to_string(), database: "mydb".to_string(), + credential: Credential::SecretRef { + name: "my-pg-secret".to_string(), + }, }; let json = serde_json::to_string(&source).unwrap(); assert!(json.contains(r#""type":"postgres""#)); assert!(json.contains(r#""host":"localhost""#)); + assert!(json.contains(r#""my-pg-secret""#)); let parsed: Source = serde_json::from_str(&json).unwrap(); assert_eq!(source, parsed); @@ -123,10 +127,12 @@ mod tests { let source = Source::Snowflake { account: "xyz123".to_string(), user: "bob".to_string(), - password: "secret".to_string(), warehouse: "COMPUTE_WH".to_string(), database: "PROD".to_string(), role: Some("ANALYST".to_string()), + credential: Credential::SecretRef { + name: "snowflake-secret".to_string(), + }, }; let json = serde_json::to_string(&source).unwrap(); @@ -142,26 +148,30 @@ mod tests { let source = Source::Snowflake { account: "xyz123".to_string(), user: "bob".to_string(), - password: "secret".to_string(), warehouse: "COMPUTE_WH".to_string(), database: "PROD".to_string(), role: None, + credential: Credential::SecretRef { + name: "secret".to_string(), + }, }; let json = serde_json::to_string(&source).unwrap(); - assert!(!json.contains("role")); + assert!(!json.contains(r#""role""#)); } #[test] fn test_motherduck_serialization() { let source = Source::Motherduck { - token: "md_abc123".to_string(), database: "my_db".to_string(), + credential: Credential::SecretRef { + name: "md-token".to_string(), + }, }; let json = serde_json::to_string(&source).unwrap(); assert!(json.contains(r#""type":"motherduck""#)); - assert!(json.contains(r#""token":"md_abc123""#)); + assert!(json.contains(r#""md-token""#)); let parsed: Source = serde_json::from_str(&json).unwrap(); assert_eq!(source, parsed); @@ -170,8 +180,8 @@ mod tests { #[test] fn test_catalog_method() { let motherduck = Source::Motherduck { - token: "t".to_string(), database: "my_database".to_string(), + credential: Credential::None, }; assert_eq!(motherduck.catalog(), Some("my_database")); @@ -184,18 +194,18 @@ mod tests { host: "localhost".to_string(), port: 5432, user: "u".to_string(), - password: "p".to_string(), database: "d".to_string(), + credential: Credential::None, }; assert_eq!(postgres.catalog(), None); let snowflake = Source::Snowflake { account: "a".to_string(), user: "u".to_string(), - password: "p".to_string(), warehouse: "w".to_string(), database: "d".to_string(), role: None, + credential: Credential::None, }; assert_eq!(snowflake.catalog(), None); } @@ -206,25 +216,47 @@ mod tests { host: "localhost".to_string(), port: 5432, user: "u".to_string(), - password: "p".to_string(), database: "d".to_string(), + credential: Credential::None, }; assert_eq!(postgres.source_type(), "postgres"); let snowflake = Source::Snowflake { account: "a".to_string(), user: "u".to_string(), - password: "p".to_string(), warehouse: "w".to_string(), database: "d".to_string(), role: None, + credential: Credential::None, }; assert_eq!(snowflake.source_type(), "snowflake"); let motherduck = Source::Motherduck { - token: "t".to_string(), database: "d".to_string(), + credential: Credential::None, }; assert_eq!(motherduck.source_type(), "motherduck"); } + + #[test] + fn test_credential_accessor() { + let with_secret = Source::Postgres { + host: "h".to_string(), + port: 5432, + user: "u".to_string(), + database: "d".to_string(), + credential: Credential::SecretRef { 
+ name: "my-secret".to_string(), + }, + }; + assert!(matches!( + with_secret.credential(), + Credential::SecretRef { name } if name == "my-secret" + )); + + let duckdb = Source::Duckdb { + path: "/p".to_string(), + }; + assert!(matches!(duckdb.credential(), Credential::None)); + } } diff --git a/tests/catalog_manager_suite.rs b/tests/catalog_manager_suite.rs index 3f0aa3e..029ecb2 100644 --- a/tests/catalog_manager_suite.rs +++ b/tests/catalog_manager_suite.rs @@ -460,6 +460,32 @@ macro_rules! catalog_manager_tests { catalog.close().await.unwrap(); catalog.close().await.unwrap(); } + + #[tokio::test] + async fn create_secret_metadata_duplicate_fails() { + use rivetdb::secrets::{SecretMetadata, SecretStatus}; + + let ctx = super::$setup_fn().await; + let catalog = ctx.manager(); + let now = chrono::Utc::now(); + + let metadata = SecretMetadata { + name: "my-secret".to_string(), + provider: "encrypted".to_string(), + provider_ref: None, + status: SecretStatus::Active, + created_at: now, + updated_at: now, + }; + + // First create should succeed + catalog.create_secret_metadata(&metadata).await.unwrap(); + + // Second create with same name should fail (unique constraint) + let result = catalog.create_secret_metadata(&metadata).await; + + assert!(result.is_err()); + } } }; } diff --git a/tests/datafetch_tests.rs b/tests/datafetch_tests.rs index 1ed9858..fb0967a 100644 --- a/tests/datafetch_tests.rs +++ b/tests/datafetch_tests.rs @@ -1,16 +1,40 @@ //! Integration tests for datafetch module +use rivetdb::catalog::{CatalogManager, SqliteCatalogManager}; use rivetdb::datafetch::{DataFetcher, NativeFetcher}; +use rivetdb::secrets::{EncryptedCatalogBackend, SecretManager, ENCRYPTED_PROVIDER_TYPE}; use rivetdb::source::Source; +use std::sync::Arc; +use tempfile::TempDir; + +/// Create a test SecretManager with temporary storage. +/// The secret manager is required by the API even for sources that don't use credentials. 
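With `connection_string()` removed from `Source`, a fetcher now resolves the credential just-in-time and assembles the connection string itself. A sketch mirroring the removed Postgres logic (helper name hypothetical; it still URL-encodes user-provided values with the `urlencoding` crate to prevent connection string injection):

```
// Hypothetical fetcher-side helper: resolve the secret, then build the URL.
async fn pg_conn_string(src: &Source, secrets: &SecretManager) -> anyhow::Result<String> {
    if let Source::Postgres { host, port, user, database, credential } = src {
        // Plaintext exists only here, just-in-time, never in the stored config.
        let password = credential.resolve(secrets).await?;
        Ok(format!(
            "postgresql://{}:{}@{}:{}/{}",
            urlencoding::encode(user),
            urlencoding::encode(&password),
            urlencoding::encode(host),
            port,
            urlencoding::encode(database)
        ))
    } else {
        anyhow::bail!("not a postgres source")
    }
}
```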
+async fn test_secret_manager(dir: &TempDir) -> SecretManager { + let db_path = dir.path().join("test_catalog.db"); + let catalog = Arc::new( + SqliteCatalogManager::new(db_path.to_str().unwrap()) + .await + .unwrap(), + ); + catalog.run_migrations().await.unwrap(); + + let key = [0x42u8; 32]; + let backend = Arc::new(EncryptedCatalogBackend::new(key, catalog.clone())); + + SecretManager::new(backend, catalog, ENCRYPTED_PROVIDER_TYPE) +} #[tokio::test] async fn test_duckdb_discovery_empty() { + let temp_dir = TempDir::new().unwrap(); + let secrets = test_secret_manager(&temp_dir).await; + let fetcher = NativeFetcher::new(); let source = Source::Duckdb { path: ":memory:".to_string(), }; - let result = fetcher.discover_tables(&source).await; + let result = fetcher.discover_tables(&source, &secrets).await; assert!( result.is_ok(), "Discovery should succeed: {:?}", @@ -23,8 +47,10 @@ async fn test_duckdb_discovery_empty() { #[tokio::test] async fn test_duckdb_discovery_with_table() { + let temp_dir = TempDir::new().unwrap(); + let secrets = test_secret_manager(&temp_dir).await; + // Create a temp file for DuckDB - let temp_dir = tempfile::tempdir().unwrap(); let db_path = temp_dir.path().join("test.duckdb"); // Create table using duckdb crate directly @@ -43,7 +69,7 @@ async fn test_duckdb_discovery_with_table() { path: db_path.to_str().unwrap().to_string(), }; - let result = fetcher.discover_tables(&source).await; + let result = fetcher.discover_tables(&source, &secrets).await; assert!( result.is_ok(), "Discovery should succeed: {:?}", @@ -64,18 +90,21 @@ async fn test_duckdb_discovery_with_table() { #[tokio::test] async fn test_unsupported_driver() { + let temp_dir = TempDir::new().unwrap(); + let secrets = test_secret_manager(&temp_dir).await; + let fetcher = NativeFetcher::new(); // Use a Snowflake source which is not implemented let source = Source::Snowflake { account: "fake".to_string(), user: "fake".to_string(), - password: "fake".to_string(), warehouse: "fake".to_string(), database: "fake".to_string(), role: None, + credential: rivetdb::source::Credential::None, }; - let result = fetcher.discover_tables(&source).await; + let result = fetcher.discover_tables(&source, &secrets).await; assert!(result.is_err(), "Should fail for unsupported driver"); } @@ -85,8 +114,10 @@ async fn test_duckdb_fetch_table() { use rivetdb::datafetch::StreamingParquetWriter; use std::fs::File; + let temp_dir = TempDir::new().unwrap(); + let secrets = test_secret_manager(&temp_dir).await; + // Create a temp DuckDB database with test data - let temp_dir = tempfile::tempdir().unwrap(); let db_path = temp_dir.path().join("test.duckdb"); // Populate database with test data (multiple types) @@ -124,7 +155,14 @@ async fn test_duckdb_fetch_table() { let mut writer = StreamingParquetWriter::new(output_path.clone()); let result = fetcher - .fetch_table(&source, None, "test_schema", "products", &mut writer) + .fetch_table( + &source, + &secrets, + None, + "test_schema", + "products", + &mut writer, + ) .await; assert!(result.is_ok(), "Fetch should succeed: {:?}", result.err()); diff --git a/tests/engine_sync_tests.rs b/tests/engine_sync_tests.rs index d4434fe..a1a0dee 100644 --- a/tests/engine_sync_tests.rs +++ b/tests/engine_sync_tests.rs @@ -1,14 +1,27 @@ use anyhow::Result; +use base64::{engine::general_purpose::STANDARD, Engine}; +use rand::RngCore; use rivetdb::RivetEngine; use tempfile::tempdir; +/// Generate a test secret key (base64-encoded 32 bytes) +fn generate_test_secret_key() -> String { + let mut key = 
[0u8; 32]; + rand::thread_rng().fill_bytes(&mut key); + STANDARD.encode(key) +} + /// Test that sync_connection handles non-existent connections correctly #[tokio::test] #[ignore] async fn test_sync_connection_not_found() -> Result<()> { let dir = tempdir()?; - let engine = RivetEngine::defaults(dir.path()).await?; + let engine = RivetEngine::builder() + .base_dir(dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; // Try to sync a connection that doesn't exist let result = engine.sync_connection("nonexistent").await; @@ -26,7 +39,11 @@ async fn test_sync_connection_not_found() -> Result<()> { async fn test_sync_connection_no_tables() -> Result<()> { let dir = tempdir()?; - let engine = RivetEngine::defaults(dir.path()).await?; + let engine = RivetEngine::builder() + .base_dir(dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; // Add a connection with no tables let config = serde_json::json!({ diff --git a/tests/http_server_tests.rs b/tests/http_server_tests.rs index 7d1bde6..e7ce99c 100644 --- a/tests/http_server_tests.rs +++ b/tests/http_server_tests.rs @@ -5,19 +5,33 @@ use axum::{ http::{Request, StatusCode}, Router, }; +use base64::{engine::general_purpose::STANDARD, Engine}; +use rand::RngCore; use rivetdb::http::app_server::{ - AppServer, PATH_CONNECTIONS, PATH_CONNECTION_DISCOVER, PATH_QUERY, PATH_TABLES, + AppServer, PATH_CONNECTIONS, PATH_CONNECTION_DISCOVER, PATH_QUERY, PATH_SECRET, PATH_SECRETS, + PATH_TABLES, }; use rivetdb::RivetEngine; use serde_json::json; use tempfile::TempDir; use tower::util::ServiceExt; +/// Generate a test secret key (base64-encoded 32 bytes) +fn generate_test_secret_key() -> String { + let mut key = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut key); + STANDARD.encode(key) +} + /// Create test router with in-memory engine async fn setup_test() -> Result<(Router, TempDir)> { let temp_dir = tempfile::tempdir()?; - let engine = RivetEngine::defaults(temp_dir.path()).await?; + let engine = RivetEngine::builder() + .base_dir(temp_dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; let app = AppServer::new(engine); @@ -451,6 +465,405 @@ async fn test_create_connection_missing_fields() -> Result<()> { Ok(()) } +// ==================== Secret Endpoint Tests ==================== + +#[tokio::test(flavor = "multi_thread")] +async fn test_list_secrets_empty() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + let response = app + .oneshot( + Request::builder() + .method("GET") + .uri(PATH_SECRETS) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert!(json["secrets"].is_array()); + assert_eq!(json["secrets"].as_array().unwrap().len(), 0); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_create_and_get_secret() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let engine = RivetEngine::builder() + .base_dir(temp_dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; + let app = AppServer::new(engine); + + // Create a secret + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_SECRETS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "test_secret", + "value": "super_secret_value" + }))?))?, + ) + .await?; + + 
assert_eq!(response.status(), StatusCode::CREATED); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert_eq!(json["name"], "test_secret"); + assert!(json["created_at"].is_string()); + + // Fetch the secret metadata + let secret_path = PATH_SECRET.replace("{name}", "test_secret"); + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri(&secret_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert_eq!(json["name"], "test_secret"); + assert!(json["created_at"].is_string()); + assert!(json["updated_at"].is_string()); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_list_secrets_after_create() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let engine = RivetEngine::builder() + .base_dir(temp_dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; + let app = AppServer::new(engine); + + // Create two secrets + for name in ["secret_one", "secret_two"] { + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_SECRETS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": name, + "value": format!("value_for_{}", name) + }))?))?, + ) + .await?; + assert_eq!(response.status(), StatusCode::CREATED); + } + + // List secrets + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri(PATH_SECRETS) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + let secrets = json["secrets"].as_array().unwrap(); + assert_eq!(secrets.len(), 2); + + let names: Vec<&str> = secrets + .iter() + .map(|s| s["name"].as_str().unwrap()) + .collect(); + assert!(names.contains(&"secret_one")); + assert!(names.contains(&"secret_two")); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_delete_secret() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let engine = RivetEngine::builder() + .base_dir(temp_dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; + let app = AppServer::new(engine); + + // Create a secret + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_SECRETS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "to_delete", + "value": "will_be_deleted" + }))?))?, + ) + .await?; + assert_eq!(response.status(), StatusCode::CREATED); + + // Delete the secret + let secret_path = PATH_SECRET.replace("{name}", "to_delete"); + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("DELETE") + .uri(&secret_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + + // Verify it's gone by listing + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("GET") + .uri(PATH_SECRETS) + .body(Body::empty())?, + ) + .await?; + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + 
assert_eq!(json["secrets"].as_array().unwrap().len(), 0); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_get_nonexistent_secret() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + let secret_path = PATH_SECRET.replace("{name}", "does_not_exist"); + let response = app + .oneshot( + Request::builder() + .method("GET") + .uri(&secret_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let json: serde_json::Value = serde_json::from_slice(&body)?; + + assert!(json["error"]["message"] + .as_str() + .unwrap() + .contains("not found")); + assert_eq!(json["error"]["code"], "NOT_FOUND"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_delete_nonexistent_secret() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + let secret_path = PATH_SECRET.replace("{name}", "does_not_exist"); + let response = app + .oneshot( + Request::builder() + .method("DELETE") + .uri(&secret_path) + .body(Body::empty())?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_create_secret_missing_fields() -> Result<()> { + let (app, _tempdir) = setup_test().await?; + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_SECRETS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "test_secret" + // missing "value" field + }))?))?, + ) + .await?; + + // Axum returns UNPROCESSABLE_ENTITY (422) when required fields are missing + assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_update_secret() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let engine = RivetEngine::builder() + .base_dir(temp_dir.path()) + .secret_key(generate_test_secret_key()) + .build() + .await?; + let app = AppServer::new(engine); + + // Create a secret + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri(PATH_SECRETS) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "name": "test_secret", + "value": "original_value" + }))?))?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::CREATED); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let create_json: serde_json::Value = serde_json::from_slice(&body)?; + let created_at = create_json["created_at"].as_str().unwrap(); + + // Small delay to ensure updated_at will be different + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Update the secret + let secret_path = PATH_SECRET.replace("{name}", "test_secret"); + let response = app + .router + .clone() + .oneshot( + Request::builder() + .method("PUT") + .uri(&secret_path) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&json!({ + "value": "updated_value" + }))?))?, + ) + .await?; + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?; + let update_json: serde_json::Value = serde_json::from_slice(&body)?; + + assert_eq!(update_json["name"], "test_secret"); + assert!(update_json["updated_at"].is_string()); + + // Verify updated_at is different from created_at + let updated_at = update_json["updated_at"].as_str().unwrap(); + 
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_update_secret() -> Result<()> {
+    let temp_dir = tempfile::tempdir()?;
+    let engine = RivetEngine::builder()
+        .base_dir(temp_dir.path())
+        .secret_key(generate_test_secret_key())
+        .build()
+        .await?;
+    let app = AppServer::new(engine);
+
+    // Create a secret
+    let response = app
+        .router
+        .clone()
+        .oneshot(
+            Request::builder()
+                .method("POST")
+                .uri(PATH_SECRETS)
+                .header("content-type", "application/json")
+                .body(Body::from(serde_json::to_string(&json!({
+                    "name": "test_secret",
+                    "value": "original_value"
+                }))?))?,
+        )
+        .await?;
+
+    assert_eq!(response.status(), StatusCode::CREATED);
+
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
+    let create_json: serde_json::Value = serde_json::from_slice(&body)?;
+    let created_at = create_json["created_at"].as_str().unwrap();
+
+    // Small delay to ensure updated_at will be different
+    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+
+    // Update the secret
+    let secret_path = PATH_SECRET.replace("{name}", "test_secret");
+    let response = app
+        .router
+        .clone()
+        .oneshot(
+            Request::builder()
+                .method("PUT")
+                .uri(&secret_path)
+                .header("content-type", "application/json")
+                .body(Body::from(serde_json::to_string(&json!({
+                    "value": "updated_value"
+                }))?))?,
+        )
+        .await?;
+
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
+    let update_json: serde_json::Value = serde_json::from_slice(&body)?;
+
+    assert_eq!(update_json["name"], "test_secret");
+    assert!(update_json["updated_at"].is_string());
+
+    // Verify updated_at is different from created_at
+    let updated_at = update_json["updated_at"].as_str().unwrap();
+    assert_ne!(
+        created_at, updated_at,
+        "updated_at should change after update"
+    );
+
+    // Verify the new value can be retrieved via the manager
+    let secret_value = app.engine.secret_manager().get("test_secret").await?;
+    assert_eq!(secret_value, b"updated_value");
+
+    // Verify metadata via GET shows updated timestamps
+    let response = app
+        .router
+        .clone()
+        .oneshot(
+            Request::builder()
+                .method("GET")
+                .uri(&secret_path)
+                .body(Body::empty())?,
+        )
+        .await?;
+
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
+    let get_json: serde_json::Value = serde_json::from_slice(&body)?;
+
+    assert_eq!(get_json["updated_at"].as_str().unwrap(), updated_at);
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_update_nonexistent_secret() -> Result<()> {
+    let (app, _tempdir) = setup_test().await?;
+
+    let secret_path = PATH_SECRET.replace("{name}", "does_not_exist");
+    let response = app
+        .oneshot(
+            Request::builder()
+                .method("PUT")
+                .uri(&secret_path)
+                .header("content-type", "application/json")
+                .body(Body::from(serde_json::to_string(&json!({
+                    "value": "some_value"
+                }))?))?,
+        )
+        .await?;
+
+    assert_eq!(response.status(), StatusCode::NOT_FOUND);
+
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX).await?;
+    let json: serde_json::Value = serde_json::from_slice(&body)?;
+
+    assert!(json["error"]["message"]
+        .as_str()
+        .unwrap()
+        .contains("not found"));
+    assert_eq!(json["error"]["code"], "NOT_FOUND");
+
+    Ok(())
+}
+
 // ==================== Decoupled Registration/Discovery Tests ====================
 
 #[tokio::test(flavor = "multi_thread")]
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 8c26fa1..3bd236e 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -299,12 +299,20 @@
                 host,
                 port,
                 user,
-                password,
                 database,
-            } => (
-                "postgres",
-                json!({ "host": host, "port": port, "user": user, "password": password, "database": database }),
-            ),
+                credential,
+            } => {
+                let cred_json = match credential {
+                    rivetdb::source::Credential::None => json!({"type": "none"}),
+                    rivetdb::source::Credential::SecretRef { name } => {
+                        json!({"type": "secret_ref", "name": name})
+                    }
+                };
+                (
+                    "postgres",
+                    json!({ "host": host, "port": port, "user": user, "database": database, "credential": cred_json }),
+                )
+            }
             _ => panic!("Unsupported source type"),
         };
@@ -473,6 +481,13 @@
     }
 }
 
+/// Generate a random base64-encoded 32-byte key for test secret manager.
+fn generate_test_secret_key() -> String {
+    use base64::Engine;
+    let key_bytes: [u8; 32] = rand::random();
+    base64::engine::general_purpose::STANDARD.encode(key_bytes)
+}
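+
+// How the engine presumably consumes this key (per the design doc): decode the
+// base64 string back to 32 raw bytes and use them as the AES-256-GCM-SIV key.
+// Hypothetical sketch, not the engine's actual code:
+//
+//     use aes_gcm_siv::{aead::KeyInit, Aes256GcmSiv};
+//     use base64::Engine;
+//
+//     let key = base64::engine::general_purpose::STANDARD
+//         .decode(&key_b64)
+//         .expect("secret key must be valid base64");
+//     let cipher =
+//         Aes256GcmSiv::new_from_slice(&key).expect("key must be 32 bytes");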
+
 /// Test context providing both executors.
 struct TestHarness {
     engine_executor: EngineExecutor,
@@ -485,7 +500,15 @@
 impl TestHarness {
     async fn new() -> Self {
        let temp_dir = TempDir::new().unwrap();
 
-        let engine = RivetEngine::defaults(temp_dir.path()).await.unwrap();
+        // Generate a test secret key to enable the secret manager
+        let secret_key = generate_test_secret_key();
+
+        let engine = RivetEngine::builder()
+            .base_dir(temp_dir.path())
+            .secret_key(secret_key)
+            .build()
+            .await
+            .unwrap();
 
         let app = AppServer::new(engine);
@@ -503,6 +526,15 @@
     fn api(&self) -> &dyn TestExecutor {
         &self.api_executor
     }
+
+    /// Store a secret for use in connection credentials.
+    async fn store_secret(&self, name: &str, value: &str) {
+        let secret_manager = self.engine_executor.engine.secret_manager();
+        secret_manager
+            .create(name, value.as_bytes())
+            .await
+            .expect("Failed to store test secret");
+    }
 }
 
 // ============================================================================
@@ -726,6 +758,9 @@ mod postgres_fixtures {
     use testcontainers::{runners::AsyncRunner, ContainerAsync, ImageExt};
     use testcontainers_modules::postgres::Postgres;
 
+    /// The password used for test Postgres containers.
+    pub const TEST_PASSWORD: &str = "postgres";
+
     pub struct PostgresFixture {
         #[allow(dead_code)]
         pub container: ContainerAsync<Postgres>,
@@ -739,11 +774,14 @@ mod postgres_fixtures {
             .await
             .expect("Failed to start postgres");
         let port = container.get_host_port_ipv4(5432).await.unwrap();
-        let conn_str = format!("postgres://postgres:postgres@localhost:{}/postgres", port);
+        let conn_str = format!(
+            "postgres://postgres:{}@localhost:{}/postgres",
+            TEST_PASSWORD, port
+        );
         (container, conn_str)
     }
 
-    pub async fn standard() -> PostgresFixture {
+    pub async fn standard(secret_name: &str) -> PostgresFixture {
         let (container, conn_str) = start_container().await;
 
         let pool = sqlx::PgPool::connect(&conn_str).await.unwrap();
@@ -766,13 +804,15 @@ mod postgres_fixtures {
                 host: "localhost".into(),
                 port,
                 user: "postgres".into(),
-                password: "postgres".into(),
                 database: "postgres".into(),
+                credential: rivetdb::source::Credential::SecretRef {
+                    name: secret_name.to_string(),
+                },
             },
         }
     }
 
-    pub async fn multi_schema() -> PostgresFixture {
+    pub async fn multi_schema(secret_name: &str) -> PostgresFixture {
         let (container, conn_str) = start_container().await;
 
         let pool = sqlx::PgPool::connect(&conn_str).await.unwrap();
@@ -809,8 +849,10 @@ mod postgres_fixtures {
                 host: "localhost".into(),
                 port,
                 user: "postgres".into(),
-                password: "postgres".into(),
                 database: "postgres".into(),
+                credential: rivetdb::source::Credential::SecretRef {
+                    name: secret_name.to_string(),
+                },
             },
         }
     }
@@ -901,31 +943,46 @@ mod duckdb_tests {
 mod postgres_tests {
     use super::*;
 
+    const PG_SECRET_NAME: &str = "pg-test-password";
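+    // Each test stores the container password under this name before creating
+    // the fixture: per the design, `secret_ref` credentials are resolved only
+    // when the external connection is created, so the secret must already
+    // exist by the time `run_*_test` registers the source.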
+
     #[tokio::test(flavor = "multi_thread")]
     async fn test_engine_golden_path() {
-        let fixture = postgres_fixtures::standard().await;
         let harness = TestHarness::new().await;
+        // Store the password as a secret before creating the fixture
+        harness
+            .store_secret(PG_SECRET_NAME, postgres_fixtures::TEST_PASSWORD)
+            .await;
+        let fixture = postgres_fixtures::standard(PG_SECRET_NAME).await;
         run_golden_path_test(harness.engine(), &fixture.source, "pg_conn").await;
     }
 
     #[tokio::test(flavor = "multi_thread")]
     async fn test_api_golden_path() {
-        let fixture = postgres_fixtures::standard().await;
         let harness = TestHarness::new().await;
+        harness
+            .store_secret(PG_SECRET_NAME, postgres_fixtures::TEST_PASSWORD)
+            .await;
+        let fixture = postgres_fixtures::standard(PG_SECRET_NAME).await;
         run_golden_path_test(harness.api(), &fixture.source, "pg_conn").await;
     }
 
     #[tokio::test(flavor = "multi_thread")]
     async fn test_engine_multi_schema() {
-        let fixture = postgres_fixtures::multi_schema().await;
         let harness = TestHarness::new().await;
+        harness
+            .store_secret(PG_SECRET_NAME, postgres_fixtures::TEST_PASSWORD)
+            .await;
+        let fixture = postgres_fixtures::multi_schema(PG_SECRET_NAME).await;
         run_multi_schema_test(harness.engine(), &fixture.source, "pg_conn").await;
     }
 
     #[tokio::test(flavor = "multi_thread")]
     async fn test_api_multi_schema() {
-        let fixture = postgres_fixtures::multi_schema().await;
         let harness = TestHarness::new().await;
+        harness
+            .store_secret(PG_SECRET_NAME, postgres_fixtures::TEST_PASSWORD)
+            .await;
+        let fixture = postgres_fixtures::multi_schema(PG_SECRET_NAME).await;
         run_multi_schema_test(harness.api(), &fixture.source, "pg_conn").await;
     }
 }