From 0c757b4751e050bc5a38d27894b775a2f4c8d6ee Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Fri, 14 Feb 2025 10:58:02 -0800 Subject: [PATCH 01/30] deps: install utoipa --- rust/frontend/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index 9d0321c8b65..9848121bf7b 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -41,3 +41,4 @@ chroma-system = { workspace = true } chroma-tracing = { workspace = true } chroma-types = { workspace = true } chroma-sqlite = { workspace = true } +utoipa-axum = "0.2.0" From 87ce3b88b897bfefd4a675a13afa0db2ea86a510 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Fri, 14 Feb 2025 11:12:18 -0800 Subject: [PATCH 02/30] api docs: minimal integration --- Cargo.lock | 201 ++++++++++++++++++++++++++++++++++-- rust/frontend/Cargo.toml | 2 + rust/frontend/src/server.rs | 19 ++++ 3 files changed, 215 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bb591a81c71..e2e246f6752 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -128,6 +128,15 @@ version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +[[package]] +name = "arbitrary" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arc-swap" version = "1.7.1" @@ -867,7 +876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "axum-macros", "bytes", "futures-util", @@ -877,7 +886,7 @@ dependencies = [ "hyper 1.5.1", "hyper-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -895,6 +904,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 1.0.2", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -916,6 +951,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-macros" version = "0.4.2" @@ -1103,9 +1157,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" @@ -1326,7 +1380,7 @@ name = "chroma-frontend" version = "0.7.0" dependencies = [ "async-trait", - "axum", + "axum 0.7.9", "backon", "chroma-cache", "chroma-config", @@ -1355,6 +1409,9 @@ dependencies = [ "tower 0.4.13", "tower-http 0.6.2", "tracing", + "utoipa", + "utoipa-axum", + "utoipa-swagger-ui", "uuid", "validator", ] @@ -1399,7 +1456,7 @@ name = "chroma-load" version = "0.1.0" dependencies = [ "async-trait", - "axum", + "axum 0.7.9", "chromadb", "chrono", "clap", @@ -2156,6 +2213,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "derive_utils" version = "0.14.2" @@ -3445,6 +3513,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.2", + "serde", ] [[package]] @@ -3885,6 +3954,12 @@ dependencies = [ "serde", ] +[[package]] +name = "lockfree-object-pool" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" + [[package]] name = "log" version = "0.4.22" @@ -4009,6 +4084,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -4084,6 +4165,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -6203,6 +6294,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "siphasher" version = "1.0.1" @@ -7063,7 +7160,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.9", "base64 0.22.1", "bytes", "h2 0.4.7", @@ -7334,6 +7431,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -7432,6 +7535,59 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utoipa" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435c6f69ef38c9017b4b4eea965dfb91e71e53d869e896db40d1cf2441dd75c0" +dependencies = [ + "indexmap 2.6.0", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-axum" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c25bae5bccc842449ec0c5ddc5cbb6a3a1eaeac4503895dc105a1138f8234a0" +dependencies = [ + "axum 0.8.1", + "paste", + "tower-layer", + "tower-service", + "utoipa", +] + +[[package]] +name = "utoipa-gen" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a77d306bc75294fd52f3e99b13ece67c02c1a2789190a6f31d32f736624326f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "161166ec520c50144922a625d8bc4925cc801b2dda958ab69878527c0e5c5d61" +dependencies = [ + "base64 0.22.1", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "url", + "utoipa", + "zip", +] + [[package]] name = "uuid" version = "1.11.0" @@ -8223,6 +8379,37 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zip" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae9c1ea7b3a5e1f4b922ff856a129881167511563dc219869afe3787fc0c1a45" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap 2.6.0", + "memchr", + "thiserror 2.0.4", + "zopfli", +] + +[[package]] +name = "zopfli" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946" +dependencies = [ + "bumpalo", + "crc32fast", + "lockfree-object-pool", + "log", + "once_cell", + "simd-adler32", +] + [[package]] name = "zstd" version = "0.13.0" diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index 9848121bf7b..fc017223e11 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -41,4 +41,6 @@ chroma-system = { workspace = true } chroma-tracing = { workspace = true } chroma-types = { workspace = true } chroma-sqlite = { workspace = true } +utoipa = "5.0.0" utoipa-axum = "0.2.0" +utoipa-swagger-ui = "9.0.0" \ No newline at end of file diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 2106e631707..6addd8734ce 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -30,6 +30,7 @@ use std::sync::{ Arc, }; use uuid::Uuid; +use utoipa::OpenApi; use crate::{ ac::AdmissionControlledService, @@ -204,6 +205,7 @@ impl FrontendServer { "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/query", post(collection_query), ) + .route("/openapi.json", get(openapi)) .with_state(server) .layer(DefaultBodyLimit::max(6000000)); // TODO: add to server configuration let app = add_tracing_middleware(app); @@ -1254,3 +1256,20 @@ async fn v1_deprecation_notice() -> Response { ); (StatusCode::GONE, Json(err_response)).into_response() } + +#[derive(OpenApi)] +#[openapi( + paths(openapi) +)] +struct ApiDoc; + +#[utoipa::path( + get, + path = "/openapi.json", + responses( + (status = 200, description = "JSON file", body = ()) + ) +)] +async fn openapi() -> axum::Json { + axum::Json(ApiDoc::openapi()) +} From 0d6827b6b86494f65b29cf79b86b078138f02d5a Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Fri, 14 Feb 2025 15:13:08 -0800 Subject: [PATCH 03/30] openapi: basic compiling example --- Cargo.lock | 30 +++++++----- Cargo.toml | 4 +- rust/frontend/Cargo.toml | 6 +-- rust/frontend/src/server.rs | 81 ++++++++++++++++++++++++------- rust/frontend/src/types/errors.rs | 10 +++- rust/types/Cargo.toml | 2 +- rust/types/src/api_types.rs | 27 +++++++++-- 7 files changed, 120 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e2e246f6752..07d53b7291e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -877,14 +877,11 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core 0.4.5", - "axum-macros", "bytes", "futures-util", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", - "hyper-util", "itoa", "matchit 0.7.3", "memchr", @@ -893,15 +890,10 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", "sync_wrapper 1.0.2", - "tokio", "tower 0.5.2", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -911,11 +903,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ "axum-core 0.5.0", + "axum-macros", "bytes", + "form_urlencoded", "futures-util", "http 1.1.0", "http-body 1.0.1", "http-body-util", + "hyper 1.5.1", + "hyper-util", "itoa", "matchit 0.8.4", "memchr", @@ -924,10 +920,15 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", "sync_wrapper 1.0.2", + "tokio", "tower 0.5.2", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -948,7 +949,6 @@ dependencies = [ "sync_wrapper 1.0.2", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -968,13 +968,14 @@ dependencies = [ "sync_wrapper 1.0.2", "tower-layer", "tower-service", + "tracing", ] [[package]] name = "axum-macros" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" dependencies = [ "proc-macro2", "quote", @@ -1380,7 +1381,7 @@ name = "chroma-frontend" version = "0.7.0" dependencies = [ "async-trait", - "axum 0.7.9", + "axum 0.8.1", "backon", "chroma-cache", "chroma-config", @@ -1456,7 +1457,7 @@ name = "chroma-load" version = "0.1.0" dependencies = [ "async-trait", - "axum 0.7.9", + "axum 0.8.1", "chromadb", "chrono", "clap", @@ -1699,6 +1700,7 @@ dependencies = [ "tokio", "tonic", "tonic-build", + "utoipa", "uuid", "validator", ] @@ -7568,6 +7570,7 @@ checksum = "a77d306bc75294fd52f3e99b13ece67c02c1a2789190a6f31d32f736624326f7" dependencies = [ "proc-macro2", "quote", + "regex", "syn 2.0.89", ] @@ -7577,6 +7580,7 @@ version = "9.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "161166ec520c50144922a625d8bc4925cc801b2dda958ab69878527c0e5c5d61" dependencies = [ + "axum 0.8.1", "base64 0.22.1", "mime_guess", "regex", diff --git a/Cargo.toml b/Cargo.toml index ee213e20955..f7d3c0cd210 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,8 @@ members = ["rust/benchmark", "rust/blockstore", "rust/cache", "rust/chroma", "ru [workspace.dependencies] arrow = "52.2.0" async-trait = "0.1" -axum = { version = "0.7", features = ["macros"] } + +axum = { version = "0.8", features = ["macros"] } chrono = { version = "0.4", features = ["serde"] } clap = { version = "4", features = ["derive"] } criterion = { version = "0.5", features = ["async_tokio"] } @@ -38,6 +39,7 @@ tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } +utoipa = { version = "5.0.0", features = ["axum_extras", "debug"] } sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite"] } sha2 = "0.10.8" md5 = "0.7.0" diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index fc017223e11..83d138dd76c 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -41,6 +41,6 @@ chroma-system = { workspace = true } chroma-tracing = { workspace = true } chroma-types = { workspace = true } chroma-sqlite = { workspace = true } -utoipa = "5.0.0" -utoipa-axum = "0.2.0" -utoipa-swagger-ui = "9.0.0" \ No newline at end of file +utoipa = { workspace = true } +utoipa-axum = { version = "0.2.0", features = ["debug"] } +utoipa-swagger-ui = { version = "9", features = ["axum"] } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 6addd8734ce..5c83f6c2af6 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -29,8 +29,11 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use uuid::Uuid; use utoipa::OpenApi; +use utoipa_axum::router::OpenApiRouter; +use utoipa_axum::routes; +use utoipa_swagger_ui::SwaggerUi; +use uuid::Uuid; use crate::{ ac::AdmissionControlledService, @@ -152,6 +155,14 @@ impl FrontendServer { #[allow(dead_code)] pub async fn run(server: FrontendServer) { let circuit_breaker_config = server.config.circuit_breaker.clone(); + + // Build an OpenApiRouter with only the healthcheck endpoint + let (docs_router, docs_api) = OpenApiRouter::with_openapi(ApiDoc::openapi()) + .routes(routes!(healthcheck)) + .split_for_parts(); + + let docs_router = docs_router.merge(SwaggerUi::new("/docs").url("/openapi.json", docs_api)); + let app = Router::new() // `GET /` goes to `root` .route("/api/v1/*any", get(v1_deprecation_notice).put(v1_deprecation_notice).patch(v1_deprecation_notice).delete(v1_deprecation_notice).head(v1_deprecation_notice).options(v1_deprecation_notice)) @@ -206,6 +217,7 @@ impl FrontendServer { post(collection_query), ) .route("/openapi.json", get(openapi)) + .merge(docs_router) .with_state(server) .layer(DefaultBodyLimit::max(6000000)); // TODO: add to server configuration let app = add_tracing_middleware(app); @@ -256,6 +268,13 @@ impl FrontendServer { // These handlers simply proxy the call and the relevant inputs into // the appropriate method on the `FrontendServer` struct. +#[utoipa::path( + method(get, head), + path = "/api/v2/healthcheck", + responses( + (status = 200, description = "Success", body = str, content_type = "text/plain") + ) +)] async fn healthcheck(State(server): State) -> impl IntoResponse { server.metrics.healthcheck.add(1, &[]); let res = server.frontend.healthcheck().await; @@ -263,15 +282,26 @@ async fn healthcheck(State(server): State) -> impl IntoResponse tonic::Code::Ok => StatusCode::OK, _ => StatusCode::SERVICE_UNAVAILABLE, }; - (code, Json(res)) } -async fn heartbeat( - State(server): State, -) -> Result, ServerError> { +#[utoipa::path( + get, + path = "/api/v2/heartbeat", + responses( + (status = 200, description = "Success", body = HeartbeatResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] +async fn heartbeat(State(server): State) -> impl IntoResponse { server.metrics.heartbeat.add(1, &[]); - Ok(Json(server.frontend.heartbeat().await?)) + match server.frontend.heartbeat().await { + Ok(response) => (StatusCode::OK, Json(response)).into_response(), + Err(err) => { + let error = ErrorResponse::new("HeartbeatError".to_string(), err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response() + } + } } // Dummy implementation for now @@ -284,12 +314,18 @@ async fn pre_flight_checks( })) } -async fn reset( - headers: HeaderMap, - State(mut server): State, -) -> Result, ServerError> { +#[utoipa::path( + post, + path = "/api/v2/reset", + responses( + (status = 200, description = "Reset successful", body = bool), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] +async fn reset(headers: HeaderMap, State(mut server): State) -> impl IntoResponse { server.metrics.reset.add(1, &[]); - server + match server .authenticate_and_authorize( &headers, AuthzAction::Reset, @@ -299,9 +335,20 @@ async fn reset( collection: None, }, ) - .await?; - server.frontend.reset().await?; - Ok(Json(true)) + .await + { + Err(auth_err) => { + let error = ErrorResponse::new("AuthError".to_string(), auth_err.to_string()); + (StatusCode::UNAUTHORIZED, Json(error)).into_response() + } + Ok(_) => match server.frontend.reset().await { + Ok(_) => (StatusCode::OK, Json(true)).into_response(), + Err(reset_err) => { + let error = ErrorResponse::new("ResetError".to_string(), reset_err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response() + } + }, + } } async fn version(State(server): State) -> &'static str { @@ -497,7 +544,7 @@ async fn delete_database( Ok(Json(server.frontend.delete_database(request).await?)) } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct ListCollectionsParams { limit: Option, #[serde(default)] @@ -1258,9 +1305,7 @@ async fn v1_deprecation_notice() -> Response { } #[derive(OpenApi)] -#[openapi( - paths(openapi) -)] +#[openapi(paths(healthcheck))] struct ApiDoc; #[utoipa::path( diff --git a/rust/frontend/src/types/errors.rs b/rust/frontend/src/types/errors.rs index fa513530477..214fdc859c0 100644 --- a/rust/frontend/src/types/errors.rs +++ b/rust/frontend/src/types/errors.rs @@ -6,7 +6,9 @@ use axum::{ use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{GetCollectionError, UpdateCollectionError}; use serde::Serialize; +use std::fmt; use thiserror::Error; +use utoipa::ToSchema; #[derive(Error, Debug)] pub enum ValidationError { @@ -52,7 +54,13 @@ impl From for ServerError { } } -#[derive(Serialize)] +impl fmt::Display for ServerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Serialize, ToSchema)] pub struct ErrorResponse { error: String, message: String, diff --git a/rust/types/Cargo.toml b/rust/types/Cargo.toml index 26f18dc9ee7..94147889295 100644 --- a/rust/types/Cargo.toml +++ b/rust/types/Cargo.toml @@ -19,7 +19,7 @@ tokio = { workspace = true } pyo3 = { workspace = true } validator = { workspace = true } regex = { workspace = true } - +utoipa = { workspace = true } # (Cross-crate testing dependencies) proptest = { workspace = true, optional = true } proptest-derive = { workspace = true, optional = true } diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 383c8e5106b..eab279fd59f 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -26,6 +26,7 @@ use serde_json::Value; use std::time::SystemTimeError; use thiserror::Error; use tonic::Status; +use utoipa::ToSchema; use uuid::Uuid; use validator::Validate; use validator::ValidationError; @@ -118,16 +119,36 @@ pub struct ChecklistResponse { pub max_batch_size: u32, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct HeartbeatResponse { #[serde(rename(serialize = "nanosecond heartbeat"))] pub nanosecond_heartbeat: u128, } -#[derive(Debug, Error)] +#[derive(Debug, Error, Serialize, ToSchema)] +#[error("system time error: {message}")] +pub struct ChromaSystemTimeError { + message: String, +} + +impl From for ChromaSystemTimeError { + fn from(err: SystemTimeError) -> Self { + Self { + message: err.to_string(), + } + } +} + +#[derive(Debug, Error, ToSchema)] pub enum HeartbeatError { #[error(transparent)] - CouldNotGetTime(#[from] SystemTimeError), + CouldNotGetTime(#[from] ChromaSystemTimeError), +} + +impl From for HeartbeatError { + fn from(err: SystemTimeError) -> Self { + HeartbeatError::CouldNotGetTime(ChromaSystemTimeError::from(err)) + } } impl ChromaError for HeartbeatError { From bb530e07f33ceb507c80c537e1b056d591400ecc Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Fri, 14 Feb 2025 15:29:25 -0800 Subject: [PATCH 04/30] frontend: refactor routes for axum 0.8 --- rust/frontend/src/server.rs | 39 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 5c83f6c2af6..51e37612075 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -157,15 +157,19 @@ impl FrontendServer { let circuit_breaker_config = server.config.circuit_breaker.clone(); // Build an OpenApiRouter with only the healthcheck endpoint - let (docs_router, docs_api) = OpenApiRouter::with_openapi(ApiDoc::openapi()) - .routes(routes!(healthcheck)) - .split_for_parts(); + let (docs_router, docs_api) = + OpenApiRouter::with_openapi(ApiDoc::openapi()).split_for_parts(); let docs_router = docs_router.merge(SwaggerUi::new("/docs").url("/openapi.json", docs_api)); let app = Router::new() // `GET /` goes to `root` - .route("/api/v1/*any", get(v1_deprecation_notice).put(v1_deprecation_notice).patch(v1_deprecation_notice).delete(v1_deprecation_notice).head(v1_deprecation_notice).options(v1_deprecation_notice)) + .route("/api/v1/{*any}", get(v1_deprecation_notice) + .put(v1_deprecation_notice) + .patch(v1_deprecation_notice) + .delete(v1_deprecation_notice) + .head(v1_deprecation_notice) + .options(v1_deprecation_notice)) .route("/api/v2/healthcheck", get(healthcheck)) .route("/api/v2/heartbeat", get(heartbeat)) .route("/api/v2/pre-flight-checks", get(pre_flight_checks)) @@ -173,50 +177,49 @@ impl FrontendServer { .route("/api/v2/version", get(version)) .route("/api/v2/auth/identity", get(get_user_identity)) .route("/api/v2/tenants", post(create_tenant)) - .route("/api/v2/tenants/:tenant_name", get(get_tenant)) - .route("/api/v2/tenants/:tenant_id/databases", get(list_databases).post(create_database)) - .route("/api/v2/tenants/:tenant_id/databases/:name", get(get_database).delete(delete_database)) + .route("/api/v2/tenants/{tenant_name}", get(get_tenant)) + .route("/api/v2/tenants/{tenant_id}/databases", get(list_databases).post(create_database)) + .route("/api/v2/tenants/{tenant_id}/databases/{name}", get(get_database).delete(delete_database)) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", post(create_collection).get(list_collections), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections_count", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", get(count_collections), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", get(get_collection).put(update_collection).delete(delete_collection), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/add", + "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add", post(collection_add), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/update", + "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", post(collection_update), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/upsert", + "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", post(collection_upsert), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/delete", + "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", post(collection_delete), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/count", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/count", get(collection_count), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/get", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/get", post(collection_get), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/query", + "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/query", post(collection_query), ) - .route("/openapi.json", get(openapi)) .merge(docs_router) .with_state(server) .layer(DefaultBodyLimit::max(6000000)); // TODO: add to server configuration From 036189ba3d5fb062adc111cce76a01da6af3b921 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 14:56:13 -0800 Subject: [PATCH 05/30] frontend: clean up tests + linting --- rust/frontend/src/server.rs | 12 ------------ rust/python_bindings/src/bindings.rs | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 51e37612075..3cb9957744c 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -31,7 +31,6 @@ use std::sync::{ }; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; -use utoipa_axum::routes; use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; @@ -1310,14 +1309,3 @@ async fn v1_deprecation_notice() -> Response { #[derive(OpenApi)] #[openapi(paths(healthcheck))] struct ApiDoc; - -#[utoipa::path( - get, - path = "/openapi.json", - responses( - (status = 200, description = "JSON file", body = ()) - ) -)] -async fn openapi() -> axum::Json { - axum::Json(ApiDoc::openapi()) -} diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 251dc7342d6..032449c4bc0 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -131,7 +131,7 @@ impl Bindings { fn heartbeat(&self) -> ChromaPyResult { let duration_since_epoch = std::time::SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) - .map_err(HeartbeatError::CouldNotGetTime)?; + .map_err(|err| HeartbeatError::CouldNotGetTime(err.into()))?; Ok(duration_since_epoch.as_nanos()) } From cf31ba6f876d9816d7e5cad807c6a011ff096ebb Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 15:53:01 -0800 Subject: [PATCH 06/30] openapi: add pre-flight checks and reset to spec --- rust/frontend/src/server.rs | 13 ++++++++++--- rust/types/src/api_types.rs | 8 +++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 3cb9957744c..e2f90d16d3a 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -271,7 +271,7 @@ impl FrontendServer { // the appropriate method on the `FrontendServer` struct. #[utoipa::path( - method(get, head), + get, path = "/api/v2/healthcheck", responses( (status = 200, description = "Success", body = str, content_type = "text/plain") @@ -306,7 +306,14 @@ async fn heartbeat(State(server): State) -> impl IntoResponse { } } -// Dummy implementation for now +#[utoipa::path( + get, + path = "/api/v2/pre-flight-checks", + responses( + (status = 200, description = "Pre flight checks", body = ChecklistResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn pre_flight_checks( State(server): State, ) -> Result, ServerError> { @@ -1307,5 +1314,5 @@ async fn v1_deprecation_notice() -> Response { } #[derive(OpenApi)] -#[openapi(paths(healthcheck))] +#[openapi(paths(healthcheck, pre_flight_checks, reset))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index eab279fd59f..609da7ee434 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -114,7 +114,7 @@ impl ChromaError for ResetError { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct ChecklistResponse { pub max_batch_size: u32, } @@ -1424,3 +1424,9 @@ impl ChromaError for ExecutorError { } } } + +#[derive(Debug, Serialize, ToSchema)] +pub struct ErrorResponse { + pub error: String, + pub message: String, +} From ec014a454762e195facac2235271af0686b43f33 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 16:04:12 -0800 Subject: [PATCH 07/30] openapi: add version and identity endpoints --- rust/frontend/src/server.rs | 24 +++++++++++++++++++++++- rust/types/src/api_types.rs | 2 +- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index e2f90d16d3a..95dc9f611ad 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -360,11 +360,26 @@ async fn reset(headers: HeaderMap, State(mut server): State) -> } } +#[utoipa::path( + get, + path = "/api/v2/version", + responses( + (status = 200, description = "Get server version", body = String) + ) +)] async fn version(State(server): State) -> &'static str { server.metrics.version.add(1, &[]); env!("CARGO_PKG_VERSION") } +#[utoipa::path( + get, + path = "/api/v2/auth/identity", + responses( + (status = 200, description = "Get user identity", body = GetUserIdentityResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn get_user_identity( headers: HeaderMap, State(server): State, @@ -1314,5 +1329,12 @@ async fn v1_deprecation_notice() -> Response { } #[derive(OpenApi)] -#[openapi(paths(healthcheck, pre_flight_checks, reset))] +#[openapi(paths( + healthcheck, + heartbeat, + pre_flight_checks, + reset, + version, + get_user_identity +))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 609da7ee434..4a187621bff 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -157,7 +157,7 @@ impl ChromaError for HeartbeatError { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct GetUserIdentityResponse { pub user_id: String, pub tenant: String, From 7307bf8d97ed319a6f4c24ed48ed68cd7c039138 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 16:21:23 -0800 Subject: [PATCH 08/30] openapi: add create_tenant to spec --- rust/frontend/src/server.rs | 13 ++++++++++++- rust/types/src/api_types.rs | 4 ++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 95dc9f611ad..15803c7bf33 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -388,6 +388,16 @@ async fn get_user_identity( Ok(Json(server.auth.get_user_identity(&headers).await?)) } +#[utoipa::path( + post, + path = "/api/v2/tenants", + request_body = CreateTenantRequest, + responses( + (status = 200, description = "Tenant created successfully", body = CreateTenantResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn create_tenant( headers: HeaderMap, State(mut server): State, @@ -1335,6 +1345,7 @@ async fn v1_deprecation_notice() -> Response { pre_flight_checks, reset, version, - get_user_identity + get_user_identity, + create_tenant ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 4a187621bff..24e061fc65f 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -165,7 +165,7 @@ pub struct GetUserIdentityResponse { } #[non_exhaustive] -#[derive(Deserialize, Validate)] +#[derive(Serialize, Validate, Deserialize, ToSchema)] pub struct CreateTenantRequest { #[validate(length(min = 3))] pub name: String, @@ -179,7 +179,7 @@ impl CreateTenantRequest { } } -#[derive(Serialize)] +#[derive(Serialize, Deserialize, ToSchema)] pub struct CreateTenantResponse {} #[derive(Debug, Error)] From eb3eb22bdb50ba0ea1c11a825864bd8a56e48c9f Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 16:29:34 -0800 Subject: [PATCH 09/30] openapi: add tenant endpoints --- rust/frontend/src/server.rs | 16 +++++++++++++++- rust/types/src/api_types.rs | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 15803c7bf33..662016eb1e5 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -419,6 +419,19 @@ async fn create_tenant( Ok(Json(server.frontend.create_tenant(request).await?)) } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_name}", + params( + ("tenant_name" = String, Path, description = "Tenant name or ID to retrieve") + ), + responses( + (status = 200, description = "Tenant found", body = GetTenantResponse), + (status = 404, description = "Tenant not found", body = ErrorResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn get_tenant( headers: HeaderMap, Path(name): Path, @@ -1346,6 +1359,7 @@ async fn v1_deprecation_notice() -> Response { reset, version, get_user_identity, - create_tenant + create_tenant, + get_tenant ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 24e061fc65f..641b64ee251 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -213,7 +213,7 @@ impl GetTenantRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] #[pyclass] pub struct GetTenantResponse { pub name: String, From 70ae50a212a1bf5e3c463abc57dc83a4ce539ebb Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 17:09:53 -0800 Subject: [PATCH 10/30] openapi: add database endpoints --- Cargo.lock | 1 + Cargo.toml | 2 +- rust/frontend/src/server.rs | 65 +++++++++++++++++++++++++++++++++++-- rust/types/src/api_types.rs | 6 ++-- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 07d53b7291e..f679f44648c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7572,6 +7572,7 @@ dependencies = [ "quote", "regex", "syn 2.0.89", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index f7d3c0cd210..4013b4d64ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } -utoipa = { version = "5.0.0", features = ["axum_extras", "debug"] } +utoipa = { version = "5.0.0", features = ["axum_extras", "debug", "uuid"] } sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite"] } sha2 = "0.10.8" md5 = "0.7.0" diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 662016eb1e5..bf65347b3ba 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -30,6 +30,7 @@ use std::sync::{ Arc, }; use utoipa::OpenApi; +use utoipa::ToSchema; use utoipa_axum::router::OpenApiRouter; use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; @@ -454,11 +455,24 @@ async fn get_tenant( Ok(Json(server.frontend.get_tenant(request).await?)) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, ToSchema, Debug)] struct CreateDatabasePayload { name: String, } +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant_id}/databases", + request_body = CreateDatabasePayload, + responses( + (status = 200, description = "Database created successfully", body = CreateDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID to associate with the new database") + ) +)] async fn create_database( headers: HeaderMap, Path(tenant_id): Path, @@ -498,12 +512,25 @@ async fn create_database( Ok(Json(res)) } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize, ToSchema, Debug)] struct ListDatabasesPayload { limit: Option, offset: u32, } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases", + request_body = ListDatabasesPayload, + responses( + (status = 200, description = "List of databases", body = [ListDatabasesResponse]), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID to list databases for") + ) +)] async fn list_databases( headers: HeaderMap, Path(tenant_id): Path, @@ -532,6 +559,20 @@ async fn list_databases( Ok(Json(server.frontend.list_databases(request).await?)) } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", + responses( + (status = 200, description = "Database retrieved successfully", body = GetDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Database not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Name of the database to retrieve") + ) +)] async fn get_database( headers: HeaderMap, Path((tenant_id, database_name)): Path<(String, String)>, @@ -561,6 +602,20 @@ async fn get_database( Ok(Json(res)) } +#[utoipa::path( + delete, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", + responses( + (status = 200, description = "Database deleted successfully", body = DeleteDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Database not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Name of the database to delete") + ) +)] async fn delete_database( headers: HeaderMap, Path((tenant_id, database_name)): Path<(String, String)>, @@ -1360,6 +1415,10 @@ async fn v1_deprecation_notice() -> Response { version, get_user_identity, create_tenant, - get_tenant + get_tenant, + list_databases, + create_database, + get_database, + delete_database ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 641b64ee251..e2c898d8b5b 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -261,7 +261,7 @@ impl CreateDatabaseRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct CreateDatabaseResponse {} #[derive(Error, Debug)] @@ -281,7 +281,7 @@ impl ChromaError for CreateDatabaseError { } } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, ToSchema)] #[pyo3::pyclass] pub struct Database { pub id: Uuid, @@ -409,7 +409,7 @@ impl DeleteDatabaseRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct DeleteDatabaseResponse {} #[derive(Debug, Error)] From 15fb7e649fea1d7f8623fc4c115edc14de70aa94 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:00:40 -0800 Subject: [PATCH 11/30] openapi: add collection endpoints --- Cargo.toml | 2 +- rust/frontend/src/server.rs | 113 ++++++++++++++++++++++++++++++++--- rust/types/src/api_types.rs | 2 +- rust/types/src/collection.rs | 16 ++++- rust/types/src/metadata.rs | 13 +++- 5 files changed, 132 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4013b4d64ce..cae173f14c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } -utoipa = { version = "5.0.0", features = ["axum_extras", "debug", "uuid"] } +utoipa = { version = "5.0.0", features = ["macros", "axum_extras", "debug", "uuid"] } sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite"] } sha2 = "0.10.8" md5 = "0.7.0" diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index bf65347b3ba..390989d0336 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -653,6 +653,19 @@ struct ListCollectionsParams { offset: u32, } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", + responses( + (status = 200, description = "List of collections", body = [Collection]), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name to list collections from") + ) +)] async fn list_collections( headers: HeaderMap, Path((tenant_id, database_name)): Path<(String, String)>, @@ -696,16 +709,33 @@ async fn list_collections( Ok(Json(server.frontend.list_collections(request).await?)) } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", + responses( + (status = 200, description = "Count of collections", body = u32), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name to count collections from") + ) +)] async fn count_collections( headers: HeaderMap, Path((tenant_id, database_name)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { - server.metrics.count_collections.add(1, &[]); + server.metrics.count_collections.add( + 1, + &[ + KeyValue::new("tenant_id", tenant_id.clone()), + KeyValue::new("database_name", database_name.clone()), + ], + ); tracing::info!( - "Counting collections in database [{}] for tenant [{}]", - database_name, - tenant_id + "Counting number of collections in database [{database_name}] for tenant [{tenant_id}]", ); server .authenticate_and_authorize( @@ -722,11 +752,12 @@ async fn count_collections( "op:count_collections", format!("tenant:{}", tenant_id).as_str(), ]); + let request = CountCollectionsRequest::try_new(tenant_id, database_name)?; Ok(Json(server.frontend.count_collections(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)] pub struct CreateCollectionPayload { pub name: String, pub configuration: Option, @@ -734,6 +765,20 @@ pub struct CreateCollectionPayload { pub get_or_create: bool, } +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", + request_body = CreateCollectionPayload, + responses( + (status = 200, description = "Collection created successfully", body = Collection), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name containing the new collection") + ) +)] async fn create_collection( headers: HeaderMap, Path((tenant_id, database_name)): Path<(String, String)>, @@ -781,6 +826,21 @@ async fn create_collection( Ok(Json(collection)) } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + responses( + (status = 200, description = "Collection found", body = Collection), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection") + ) +)] async fn get_collection( headers: HeaderMap, Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, @@ -808,12 +868,28 @@ async fn get_collection( Ok(Json(collection)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)] pub struct UpdateCollectionPayload { pub new_name: Option, pub new_metadata: Option, } +#[utoipa::path( + put, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + request_body = UpdateCollectionPayload, + responses( + (status = 200, description = "Collection updated successfully", body = UpdateCollectionResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection to update") + ) +)] async fn update_collection( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -866,6 +942,21 @@ async fn update_collection( Ok(Json(UpdateCollectionResponse {})) } +#[utoipa::path( + delete, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + responses( + (status = 200, description = "Collection deleted successfully", body = UpdateCollectionResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection to delete") + ) +)] async fn delete_collection( headers: HeaderMap, Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, @@ -895,7 +986,7 @@ async fn delete_collection( Ok(Json(UpdateCollectionResponse {})) } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)] pub struct AddCollectionRecordsPayload { ids: Vec, embeddings: Option>>, @@ -1419,6 +1510,12 @@ async fn v1_deprecation_notice() -> Response { list_databases, create_database, get_database, - delete_database + delete_database, + create_collection, + list_collections, + count_collections, + get_collection, + update_collection, + delete_collection ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index e2c898d8b5b..818a2819b13 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -659,7 +659,7 @@ impl UpdateCollectionRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct UpdateCollectionResponse {} #[derive(Error, Debug)] diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs index 1e91a821a67..1d0ec407b05 100644 --- a/rust/types/src/collection.rs +++ b/rust/types/src/collection.rs @@ -5,11 +5,23 @@ use pyo3::types::PyAnyMethods; use serde::{Deserialize, Serialize}; use serde_json::Value; use thiserror::Error; +use utoipa::ToSchema; use uuid::Uuid; /// CollectionUuid is a wrapper around Uuid to provide a type for the collection id. #[derive( - Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, + Copy, + Clone, + Debug, + Default, + Deserialize, + Eq, + PartialEq, + Ord, + PartialOrd, + Hash, + Serialize, + ToSchema, )] pub struct CollectionUuid(pub Uuid); @@ -36,7 +48,7 @@ impl std::fmt::Display for CollectionUuid { } } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[pyo3::pyclass] pub struct Collection { #[serde(rename(serialize = "id"))] diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index 7d98742304d..86360822fbb 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -7,10 +7,11 @@ use std::{ collections::{HashMap, HashSet}, }; use thiserror::Error; +use utoipa::ToSchema; use crate::chroma_proto; -#[derive(Clone, Debug, PartialEq, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, PartialOrd, Deserialize, Serialize, ToSchema)] #[serde(untagged)] pub enum UpdateMetadataValue { Bool(bool), @@ -118,7 +119,15 @@ MetadataValue */ #[derive( - Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize, FromPyObject, IntoPyObject, + Clone, + Debug, + Deserialize, + PartialEq, + PartialOrd, + Serialize, + FromPyObject, + IntoPyObject, + ToSchema, )] #[serde(untagged)] pub enum MetadataValue { From 259d818edd11bbe3fff4512738070d5560ce1561 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:12:58 -0800 Subject: [PATCH 12/30] openapi: add collection_add and collection_update endpoints --- rust/frontend/src/server.rs | 24 ++++++++++++++++++++++-- rust/types/src/api_types.rs | 29 ++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 390989d0336..99dff2151b3 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -995,6 +995,15 @@ pub struct AddCollectionRecordsPayload { metadatas: Option>>, } +#[utoipa::path( + post, + path = "/collection_add", + request_body = AddCollectionRecordsPayload, + responses( + (status = 201, description = "Collection added successfully", body = AddCollectionRecordsResponse), + (status = 400, description = "Invalid data for collection addition") + ) +)] async fn collection_add( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1067,7 +1076,7 @@ async fn collection_add( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema)] pub struct UpdateCollectionRecordsPayload { ids: Vec, embeddings: Option>>>, @@ -1076,6 +1085,15 @@ pub struct UpdateCollectionRecordsPayload { metadatas: Option>>, } +#[utoipa::path( + put, + path = "/collection_update", + request_body = UpdateCollectionRecordsPayload, + responses( + (status = 200, description = "Collection updated successfully"), + (status = 404, description = "Collection not found") + ) +)] async fn collection_update( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1516,6 +1534,8 @@ async fn v1_deprecation_notice() -> Response { count_collections, get_collection, update_collection, - delete_collection + delete_collection, + collection_add, + collection_update ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 818a2819b13..cbc0d317c9b 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -797,7 +797,7 @@ impl AddCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct AddCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -1430,3 +1430,30 @@ pub struct ErrorResponse { pub error: String, pub message: String, } + +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct CollectionAddRequest { + /// Name of the new collection + pub name: String, + /// Optional description + pub description: Option, + // etc. +} + +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct CollectionUpdateRequest { + /// Name of the collection to update + pub name: String, + /// Updated description or other fields + pub new_description: Option, + // etc. +} + +/// Example success response for collection create/update +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct CollectionAddResponse { + /// The updated or newly created collection name + pub name: String, + /// Any relevant status or message + pub message: String, +} From 63dd4a6df6632c0ec67a2a7838feeda1e10183dd Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:26:01 -0800 Subject: [PATCH 13/30] openapi: add collection_upsert and collection_delete methods --- rust/frontend/src/server.rs | 42 +++++++++++++++++++++++++++++---- rust/types/src/api_types.rs | 6 ++--- rust/types/src/metadata.rs | 14 +++++------ rust/types/src/where_parsing.rs | 4 +++- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 99dff2151b3..cf798d29c56 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1076,7 +1076,7 @@ async fn collection_add( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone, ToSchema)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct UpdateCollectionRecordsPayload { ids: Vec, embeddings: Option>>>, @@ -1163,7 +1163,7 @@ async fn collection_update( Ok(Json(server.frontend.update(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct UpsertCollectionRecordsPayload { ids: Vec, embeddings: Option>>, @@ -1172,6 +1172,22 @@ pub struct UpsertCollectionRecordsPayload { metadatas: Option>>, } +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", + request_body = UpsertCollectionRecordsPayload, + responses( + (status = 200, description = "Records upserted successfully", body = UpsertCollectionRecordsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID"), + ) +)] async fn collection_upsert( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1242,13 +1258,29 @@ async fn collection_upsert( Ok(Json(server.frontend.upsert(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct DeleteCollectionRecordsPayload { ids: Option>, #[serde(flatten)] where_fields: RawWhereFields, } +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", + request_body = DeleteCollectionRecordsPayload, + responses( + (status = 200, description = "Records deleted successfully", body = DeleteCollectionRecordsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID"), + ) +)] async fn collection_delete( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1536,6 +1568,8 @@ async fn v1_deprecation_notice() -> Response { update_collection, delete_collection, collection_add, - collection_update + collection_update, + collection_upsert, + collection_delete ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index cbc0d317c9b..ea19c0361c8 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -918,7 +918,7 @@ impl UpsertCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct UpsertCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -938,7 +938,7 @@ impl ChromaError for UpsertCollectionRecordsError { ////////////////////////// DeleteCollectionRecords ////////////////////////// #[non_exhaustive] -#[derive(Clone, Validate)] +#[derive(Clone, Validate, ToSchema)] pub struct DeleteCollectionRecordsRequest { pub tenant_id: String, pub database_name: String, @@ -975,7 +975,7 @@ impl DeleteCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct DeleteCollectionRecordsResponse {} #[derive(Error, Debug)] diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index 86360822fbb..6aa7fba6892 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -447,7 +447,7 @@ impl WhereConversionError { /// present we simply create a conjunction of both clauses as the actual filter. This is consistent with /// the semantics we used to have when the `where` and `where_document` clauses are treated seperately. // TODO: Remove this note once the `where` clause and `where_document` clause is unified in the API level. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum Where { Composite(CompositeExpression), Document(DocumentExpression), @@ -521,7 +521,7 @@ impl TryFrom for chroma_proto::Where { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct CompositeExpression { pub operator: BooleanOperator, pub children: Vec, @@ -557,7 +557,7 @@ impl TryFrom for chroma_proto::WhereChildren { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum BooleanOperator { And, Or, @@ -581,7 +581,7 @@ impl From for chroma_proto::BooleanOperator { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct DocumentExpression { pub operator: DocumentOperator, pub text: String, @@ -605,7 +605,7 @@ impl From for chroma_proto::DirectWhereDocument { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum DocumentOperator { Contains, NotContains, @@ -628,7 +628,7 @@ impl From for chroma_proto::WhereDocumentOperator { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct MetadataExpression { pub key: String, pub comparison: MetadataComparison, @@ -758,7 +758,7 @@ impl TryFrom for chroma_proto::DirectComparison { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum MetadataComparison { Primitive(PrimitiveOperator, MetadataValue), Set(SetOperator, MetadataSetValue), diff --git a/rust/types/src/where_parsing.rs b/rust/types/src/where_parsing.rs index 33cedc47477..757ec43cb6b 100644 --- a/rust/types/src/where_parsing.rs +++ b/rust/types/src/where_parsing.rs @@ -1,10 +1,12 @@ use crate::{CompositeExpression, DocumentOperator, MetadataExpression, PrimitiveOperator, Where}; use chroma_error::ChromaError; use serde::Deserialize; +use serde::Serialize; use serde_json::Value; use thiserror::Error; +use utoipa::ToSchema; -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Serialize, ToSchema)] pub struct RawWhereFields { #[serde(default)] r#where: Value, From 2ca496b9ef2a88376b753f3ad77c73348c98c7e6 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:36:47 -0800 Subject: [PATCH 14/30] openapi: add collection_count and collection_get endpoints --- rust/frontend/src/server.rs | 37 +++++++++++++++++++++++++++++++++++-- rust/types/src/api_types.rs | 6 +++--- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index cf798d29c56..92fabeed68a 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1333,6 +1333,21 @@ async fn collection_delete( Ok(Json(DeleteCollectionRecordsResponse {})) } +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/count", + responses( + (status = 200, description = "Number of records in the collection", body = CountResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID for the collection"), + ("database_name" = String, Path, description = "Database containing this collection"), + ("collection_id" = String, Path, description = "Collection ID whose records are counted") + ) +)] async fn collection_count( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1375,7 +1390,7 @@ async fn collection_count( Ok(Json(server.frontend.count(request).await?)) } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, ToSchema)] pub struct GetRequestPayload { ids: Option>, #[serde(flatten)] @@ -1386,6 +1401,22 @@ pub struct GetRequestPayload { include: IncludeList, } +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/get", + request_body = GetRequestPayload, + responses( + (status = 200, description = "Records retrieved from the collection", body = GetResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name for the collection"), + ("collection_id" = String, Path, description = "Collection ID to fetch records from") + ) +)] async fn collection_get( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1570,6 +1601,8 @@ async fn v1_deprecation_notice() -> Response { collection_add, collection_update, collection_upsert, - collection_delete + collection_delete, + collection_count, + collection_get ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index ea19c0361c8..e8d1fb5f6ae 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -1007,7 +1007,7 @@ impl ChromaError for IncludeParsingError { } } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] pub enum Include { #[serde(rename = "distances")] Distance, @@ -1036,7 +1036,7 @@ impl TryFrom<&str> for Include { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, PartialEq)] #[pyclass] pub struct IncludeList(pub Vec); @@ -1141,7 +1141,7 @@ impl GetRequest { } } -#[derive(Clone, Deserialize, Serialize, Debug)] +#[derive(Clone, Deserialize, Serialize, Debug, ToSchema)] #[pyclass] pub struct GetResponse { #[pyo3(get)] From da9e96c415291eb554bc35c53b07f178a526a0b4 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:45:31 -0800 Subject: [PATCH 15/30] openapi: add collection_query endpoint --- rust/frontend/src/server.rs | 22 ++++++++++++++++++++-- rust/types/src/api_types.rs | 2 +- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 92fabeed68a..b28e4e1da0c 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1481,7 +1481,7 @@ async fn collection_get( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Serialize, utoipa::ToSchema)] pub struct QueryRequestPayload { ids: Option>, #[serde(flatten)] @@ -1492,6 +1492,23 @@ pub struct QueryRequestPayload { include: IncludeList, } +/// Query a collection for nearest matches using vector search +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/query", + request_body = QueryRequestPayload, + responses( + (status = 200, description = "Records matching the query", body = QueryResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant_id" = String, Path, description = "Tenant ID"), + ("database_name" = String, Path, description = "Database name containing the collection"), + ("collection_id" = String, Path, description = "Collection ID to query") + ) +)] async fn collection_query( headers: HeaderMap, Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, @@ -1603,6 +1620,7 @@ async fn v1_deprecation_notice() -> Response { collection_upsert, collection_delete, collection_count, - collection_get + collection_get, + collection_query ))] struct ApiDoc; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index e8d1fb5f6ae..da0810c85b7 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -1257,7 +1257,7 @@ impl QueryRequest { } } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, utoipa::ToSchema)] #[pyclass] pub struct QueryResponse { #[pyo3(get)] From 3e089bfa07bb7659bfa3b56428b80298869cf06c Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Tue, 18 Feb 2025 18:52:19 -0800 Subject: [PATCH 16/30] openapi: add api key --- rust/frontend/src/server.rs | 74 +++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index b28e4e1da0c..0a9c410de10 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -29,8 +29,9 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use utoipa::OpenApi; +use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}; use utoipa::ToSchema; +use utoipa::{Modify, OpenApi}; use utoipa_axum::router::OpenApiRouter; use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; @@ -1595,32 +1596,49 @@ async fn v1_deprecation_notice() -> Response { (StatusCode::GONE, Json(err_response)).into_response() } +// Add a struct implementing Modify to inject the new security scheme +struct ChromaTokenSecurityAddon; + +impl Modify for ChromaTokenSecurityAddon { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + let components = openapi.components.as_mut().unwrap(); + components.add_security_scheme( + "x-chroma-token", + SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("x-chroma-token"))), + ); + } +} + #[derive(OpenApi)] -#[openapi(paths( - healthcheck, - heartbeat, - pre_flight_checks, - reset, - version, - get_user_identity, - create_tenant, - get_tenant, - list_databases, - create_database, - get_database, - delete_database, - create_collection, - list_collections, - count_collections, - get_collection, - update_collection, - delete_collection, - collection_add, - collection_update, - collection_upsert, - collection_delete, - collection_count, - collection_get, - collection_query -))] +#[openapi( + paths( + healthcheck, + heartbeat, + pre_flight_checks, + reset, + version, + get_user_identity, + create_tenant, + get_tenant, + list_databases, + create_database, + get_database, + delete_database, + create_collection, + list_collections, + count_collections, + get_collection, + update_collection, + delete_collection, + collection_add, + collection_update, + collection_upsert, + collection_delete, + collection_count, + collection_get, + collection_query + ), + // Apply our new security scheme here + modifiers(&ChromaTokenSecurityAddon) +)] struct ApiDoc; From d0cb299eee5398de82ce9d0b593258253890de69 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 11:04:49 -0800 Subject: [PATCH 17/30] openapi: fix list_databases and collection endpoints --- rust/frontend/src/server.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 0a9c410de10..24d99c28190 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -513,30 +513,33 @@ async fn create_database( Ok(Json(res)) } -#[derive(Deserialize, Serialize, ToSchema, Debug)] -struct ListDatabasesPayload { +#[derive(Deserialize, ToSchema, Debug)] +struct ListDatabasesQueryParams { limit: Option, + #[serde(default)] offset: u32, } #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases", - request_body = ListDatabasesPayload, + // Remove the request_body line entirely responses( (status = 200, description = "List of databases", body = [ListDatabasesResponse]), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID to list databases for") + ("tenant_id" = String, Path, description = "Tenant ID to list databases for"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") ) )] async fn list_databases( headers: HeaderMap, Path(tenant_id): Path, + Query(ListDatabasesQueryParams { limit, offset }): Query, State(mut server): State, - Json(ListDatabasesPayload { limit, offset }): Json, ) -> Result, ServerError> { server.metrics.list_databases.add(1, &[]); tracing::info!("Listing database for tenant [{}]", tenant_id); @@ -998,7 +1001,7 @@ pub struct AddCollectionRecordsPayload { #[utoipa::path( post, - path = "/collection_add", + path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add", request_body = AddCollectionRecordsPayload, responses( (status = 201, description = "Collection added successfully", body = AddCollectionRecordsResponse), @@ -1087,8 +1090,8 @@ pub struct UpdateCollectionRecordsPayload { } #[utoipa::path( - put, - path = "/collection_update", + post, + path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", request_body = UpdateCollectionRecordsPayload, responses( (status = 200, description = "Collection updated successfully"), From db04232ae0142b2cbc106da7aa687deecbe5d001 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 11:19:17 -0800 Subject: [PATCH 18/30] openapi: clean up --- Cargo.toml | 1 - rust/frontend/src/server.rs | 23 +++++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cae173f14c7..15252a61cf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = ["rust/benchmark", "rust/blockstore", "rust/cache", "rust/chroma", "ru [workspace.dependencies] arrow = "52.2.0" async-trait = "0.1" - axum = { version = "0.8", features = ["macros"] } chrono = { version = "0.4", features = ["serde"] } clap = { version = "4", features = ["derive"] } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 24d99c28190..d5ff1ceba87 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -157,7 +157,6 @@ impl FrontendServer { pub async fn run(server: FrontendServer) { let circuit_breaker_config = server.config.circuit_breaker.clone(); - // Build an OpenApiRouter with only the healthcheck endpoint let (docs_router, docs_api) = OpenApiRouter::with_openapi(ApiDoc::openapi()).split_for_parts(); @@ -523,7 +522,6 @@ struct ListDatabasesQueryParams { #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases", - // Remove the request_body line entirely responses( (status = 200, description = "List of databases", body = [ListDatabasesResponse]), (status = 401, description = "Unauthorized", body = ErrorResponse), @@ -667,7 +665,9 @@ struct ListCollectionsParams { ), params( ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name to list collections from") + ("database_name" = String, Path, description = "Database name to list collections from"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") ) )] async fn list_collections( @@ -731,13 +731,7 @@ async fn count_collections( Path((tenant_id, database_name)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { - server.metrics.count_collections.add( - 1, - &[ - KeyValue::new("tenant_id", tenant_id.clone()), - KeyValue::new("database_name", database_name.clone()), - ], - ); + server.metrics.count_collections.add(1, &[]); tracing::info!( "Counting number of collections in database [{database_name}] for tenant [{tenant_id}]", ); @@ -1510,7 +1504,9 @@ pub struct QueryRequestPayload { params( ("tenant_id" = String, Path, description = "Tenant ID"), ("database_name" = String, Path, description = "Database name containing the collection"), - ("collection_id" = String, Path, description = "Collection ID to query") + ("collection_id" = String, Path, description = "Collection ID to query"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") ) )] async fn collection_query( @@ -1599,7 +1595,10 @@ async fn v1_deprecation_notice() -> Response { (StatusCode::GONE, Json(err_response)).into_response() } -// Add a struct implementing Modify to inject the new security scheme +//////////////////////////////////////////////////////////// +/// OpenAPI +//////////////////////////////////////////////////////////// + struct ChromaTokenSecurityAddon; impl Modify for ChromaTokenSecurityAddon { From 562f0a56d3a2f58be0d3ec066904003201bec701 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 11:36:36 -0800 Subject: [PATCH 19/30] openapi: clean up --- rust/frontend/src/server.rs | 2 +- rust/types/src/api_types.rs | 35 +---------------------------------- 2 files changed, 2 insertions(+), 35 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index d5ff1ceba87..02257f8846d 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1479,7 +1479,7 @@ async fn collection_get( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone, Serialize, utoipa::ToSchema)] +#[derive(Deserialize, Debug, Clone, Serialize, ToSchema)] pub struct QueryRequestPayload { ids: Option>, #[serde(flatten)] diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index da0810c85b7..a8f4592e73f 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -1257,7 +1257,7 @@ impl QueryRequest { } } -#[derive(Clone, Deserialize, Serialize, utoipa::ToSchema)] +#[derive(Clone, Deserialize, Serialize, ToSchema)] #[pyclass] pub struct QueryResponse { #[pyo3(get)] @@ -1424,36 +1424,3 @@ impl ChromaError for ExecutorError { } } } - -#[derive(Debug, Serialize, ToSchema)] -pub struct ErrorResponse { - pub error: String, - pub message: String, -} - -#[derive(Serialize, Deserialize, Debug, ToSchema)] -pub struct CollectionAddRequest { - /// Name of the new collection - pub name: String, - /// Optional description - pub description: Option, - // etc. -} - -#[derive(Serialize, Deserialize, Debug, ToSchema)] -pub struct CollectionUpdateRequest { - /// Name of the collection to update - pub name: String, - /// Updated description or other fields - pub new_description: Option, - // etc. -} - -/// Example success response for collection create/update -#[derive(Serialize, Deserialize, Debug, ToSchema)] -pub struct CollectionAddResponse { - /// The updated or newly created collection name - pub name: String, - /// Any relevant status or message - pub message: String, -} From 6c2aef0e3cbcc6d84f15755b8564c628d3cfb5fc Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 11:45:24 -0800 Subject: [PATCH 20/30] openapi: clean up heartbeat --- rust/frontend/src/server.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 02257f8846d..2ed2ec51cb0 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -296,15 +296,11 @@ async fn healthcheck(State(server): State) -> impl IntoResponse (status = 500, description = "Server error", body = ErrorResponse) ) )] -async fn heartbeat(State(server): State) -> impl IntoResponse { +async fn heartbeat( + State(server): State, +) -> Result, ServerError> { server.metrics.heartbeat.add(1, &[]); - match server.frontend.heartbeat().await { - Ok(response) => (StatusCode::OK, Json(response)).into_response(), - Err(err) => { - let error = ErrorResponse::new("HeartbeatError".to_string(), err.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response() - } - } + Ok(Json(server.frontend.heartbeat().await?)) } #[utoipa::path( From 54a801ff445aa0d1511ae5966113179bd15eb888 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 14:21:46 -0800 Subject: [PATCH 21/30] openapi: clean up docs --- rust/frontend/src/server.rs | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 2ed2ec51cb0..5e4a8d4d495 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -271,6 +271,7 @@ impl FrontendServer { // These handlers simply proxy the call and the relevant inputs into // the appropriate method on the `FrontendServer` struct. +/// Health check endpoint that returns 200 if the server and executor are ready #[utoipa::path( get, path = "/api/v2/healthcheck", @@ -288,6 +289,7 @@ async fn healthcheck(State(server): State) -> impl IntoResponse (code, Json(res)) } +/// Heartbeat endpoint that returns a nanosecond timestamp of the current time. Useful for making sure the client remains connected. #[utoipa::path( get, path = "/api/v2/heartbeat", @@ -303,6 +305,7 @@ async fn heartbeat( Ok(Json(server.frontend.heartbeat().await?)) } +/// Pre-flight checks endpoint reporting basic readiness info. #[utoipa::path( get, path = "/api/v2/pre-flight-checks", @@ -320,6 +323,7 @@ async fn pre_flight_checks( })) } +/// Reset endpoint allowing authorized users to reset the database. #[utoipa::path( post, path = "/api/v2/reset", @@ -357,6 +361,7 @@ async fn reset(headers: HeaderMap, State(mut server): State) -> } } +/// Returns the version of the server. #[utoipa::path( get, path = "/api/v2/version", @@ -369,6 +374,7 @@ async fn version(State(server): State) -> &'static str { env!("CARGO_PKG_VERSION") } +/// Retrieves the current user's identity, tenant, and databases. #[utoipa::path( get, path = "/api/v2/auth/identity", @@ -385,6 +391,7 @@ async fn get_user_identity( Ok(Json(server.auth.get_user_identity(&headers).await?)) } +/// Creates a new tenant. #[utoipa::path( post, path = "/api/v2/tenants", @@ -416,6 +423,7 @@ async fn create_tenant( Ok(Json(server.frontend.create_tenant(request).await?)) } +/// Returns an existing tenant by name. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_name}", @@ -456,6 +464,7 @@ struct CreateDatabasePayload { name: String, } +/// Creates a new database for a given tenant. #[utoipa::path( post, path = "/api/v2/tenants/{tenant_id}/databases", @@ -515,6 +524,7 @@ struct ListDatabasesQueryParams { offset: u32, } +/// Lists all databases for a given tenant. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases", @@ -557,6 +567,7 @@ async fn list_databases( Ok(Json(server.frontend.list_databases(request).await?)) } +/// Retrieves a specific database by name. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", @@ -600,6 +611,7 @@ async fn get_database( Ok(Json(res)) } +/// Deletes a specific database. #[utoipa::path( delete, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", @@ -651,6 +663,7 @@ struct ListCollectionsParams { offset: u32, } +/// Lists all collections in the specified database. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", @@ -709,6 +722,7 @@ async fn list_collections( Ok(Json(server.frontend.list_collections(request).await?)) } +/// Retrieves the total number of collections in a given database. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", @@ -758,7 +772,7 @@ pub struct CreateCollectionPayload { pub metadata: Option, pub get_or_create: bool, } - +/// Creates a new collection under the specified database. #[utoipa::path( post, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", @@ -820,6 +834,7 @@ async fn create_collection( Ok(Json(collection)) } +/// Retrieves a collection by ID or name. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", @@ -868,6 +883,7 @@ pub struct UpdateCollectionPayload { pub new_metadata: Option, } +/// Updates an existing collection's name or metadata. #[utoipa::path( put, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", @@ -936,6 +952,7 @@ async fn update_collection( Ok(Json(UpdateCollectionResponse {})) } +/// Deletes a collection in a given database. #[utoipa::path( delete, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", @@ -989,6 +1006,7 @@ pub struct AddCollectionRecordsPayload { metadatas: Option>>, } +/// Adds records to a collection. #[utoipa::path( post, path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add", @@ -1079,6 +1097,7 @@ pub struct UpdateCollectionRecordsPayload { metadatas: Option>>, } +/// Updates records in a collection by ID. #[utoipa::path( post, path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", @@ -1166,6 +1185,7 @@ pub struct UpsertCollectionRecordsPayload { metadatas: Option>>, } +/// Upserts records in a collection (create if not exists, otherwise update). #[utoipa::path( post, path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", @@ -1259,6 +1279,7 @@ pub struct DeleteCollectionRecordsPayload { where_fields: RawWhereFields, } +/// Deletes records in a collection. Can filter by IDs or metadata. #[utoipa::path( post, path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", @@ -1327,6 +1348,7 @@ async fn collection_delete( Ok(Json(DeleteCollectionRecordsResponse {})) } +/// Retrieves the number of records in a collection. #[utoipa::path( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/count", @@ -1395,6 +1417,7 @@ pub struct GetRequestPayload { include: IncludeList, } +/// Retrieves records from a collection by ID or metadata filter. #[utoipa::path( post, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/get", @@ -1486,7 +1509,7 @@ pub struct QueryRequestPayload { include: IncludeList, } -/// Query a collection for nearest matches using vector search +/// Query a collection in a variety of ways, including vector search, metadata filtering, and full-text search #[utoipa::path( post, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/query", From 18d600dbc8adb11d1128c1db24bebdb9839f8ab0 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 14:33:11 -0800 Subject: [PATCH 22/30] openapi: refine returned values --- rust/frontend/src/server.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 5e4a8d4d495..477686af6a4 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -276,7 +276,8 @@ impl FrontendServer { get, path = "/api/v2/healthcheck", responses( - (status = 200, description = "Success", body = str, content_type = "text/plain") + (status = 200, description = "Success", body = String, content_type = "application/json"), + (status = 503, description = "Service Unavailable", body = ErrorResponse), ) )] async fn healthcheck(State(server): State) -> impl IntoResponse { @@ -529,7 +530,7 @@ struct ListDatabasesQueryParams { get, path = "/api/v2/tenants/{tenant_id}/databases", responses( - (status = 200, description = "List of databases", body = [ListDatabasesResponse]), + (status = 200, description = "List of databases", body = ListDatabasesResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), @@ -668,7 +669,7 @@ struct ListCollectionsParams { get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", responses( - (status = 200, description = "List of collections", body = [Collection]), + (status = 200, description = "List of collections", body = ListCollectionsResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), @@ -727,7 +728,7 @@ async fn list_collections( get, path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", responses( - (status = 200, description = "Count of collections", body = u32), + (status = 200, description = "Count of collections", body = CountCollectionsResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), @@ -1103,7 +1104,7 @@ pub struct UpdateCollectionRecordsPayload { path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", request_body = UpdateCollectionRecordsPayload, responses( - (status = 200, description = "Collection updated successfully"), + (status = 200, description = "Collection updated successfully", body = UpdateCollectionRecordsResponse), (status = 404, description = "Collection not found") ) )] From 9ae1d2cac1350f47ae96d41cab6e270750f9ad8f Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 14:56:22 -0800 Subject: [PATCH 23/30] server: revert reset endpoint --- rust/frontend/src/server.rs | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 91909e8efb4..e128547f2d4 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -334,9 +334,12 @@ async fn pre_flight_checks( (status = 500, description = "Server error", body = ErrorResponse) ) )] -async fn reset(headers: HeaderMap, State(mut server): State) -> impl IntoResponse { +async fn reset( + headers: HeaderMap, + State(mut server): State, +) -> Result, ServerError> { server.metrics.reset.add(1, &[]); - match server + server .authenticate_and_authorize( &headers, AuthzAction::Reset, @@ -346,20 +349,9 @@ async fn reset(headers: HeaderMap, State(mut server): State) -> collection: None, }, ) - .await - { - Err(auth_err) => { - let error = ErrorResponse::new("AuthError".to_string(), auth_err.to_string()); - (StatusCode::UNAUTHORIZED, Json(error)).into_response() - } - Ok(_) => match server.frontend.reset().await { - Ok(_) => (StatusCode::OK, Json(true)).into_response(), - Err(reset_err) => { - let error = ErrorResponse::new("ResetError".to_string(), reset_err.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, Json(error)).into_response() - } - }, - } + .await?; + server.frontend.reset().await?; + Ok(Json(true)) } /// Returns the version of the server. From 0e4345bf9363042e7289a07df25e41b2613a3232 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Wed, 19 Feb 2025 15:15:17 -0800 Subject: [PATCH 24/30] openapi: fix clippy errors --- rust/frontend/src/server.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index e128547f2d4..3a6e08acc65 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1580,9 +1580,7 @@ async fn v1_deprecation_notice() -> Response { //////////////////////////////////////////////////////////// /// OpenAPI //////////////////////////////////////////////////////////// - struct ChromaTokenSecurityAddon; - impl Modify for ChromaTokenSecurityAddon { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { let components = openapi.components.as_mut().unwrap(); From 8bfcb48e1d87af6058a6294f13898528e96c1230 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 10:02:23 -0800 Subject: [PATCH 25/30] frontend + openapi: standardize to `tenant` and `database` as path param names to match legacy api --- rust/frontend/src/server.rs | 416 ++++++++++++++++++------------------ 1 file changed, 205 insertions(+), 211 deletions(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 3a6e08acc65..a561c6a06d8 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -164,12 +164,15 @@ impl FrontendServer { let app = Router::new() // `GET /` goes to `root` - .route("/api/v1/{*any}", get(v1_deprecation_notice) - .put(v1_deprecation_notice) - .patch(v1_deprecation_notice) - .delete(v1_deprecation_notice) - .head(v1_deprecation_notice) - .options(v1_deprecation_notice)) + .route( + "/api/v1/{*any}", + get(v1_deprecation_notice) + .put(v1_deprecation_notice) + .patch(v1_deprecation_notice) + .delete(v1_deprecation_notice) + .head(v1_deprecation_notice) + .options(v1_deprecation_notice), + ) .route("/api/v2/healthcheck", get(healthcheck)) .route("/api/v2/heartbeat", get(heartbeat)) .route("/api/v2/pre-flight-checks", get(pre_flight_checks)) @@ -177,52 +180,60 @@ impl FrontendServer { .route("/api/v2/version", get(version)) .route("/api/v2/auth/identity", get(get_user_identity)) .route("/api/v2/tenants", post(create_tenant)) - .route("/api/v2/tenants/{tenant_name}", get(get_tenant)) - .route("/api/v2/tenants/{tenant_id}/databases", get(list_databases).post(create_database)) - .route("/api/v2/tenants/{tenant_id}/databases/{name}", get(get_database).delete(delete_database)) + .route("/api/v2/tenants/{tenant}", get(get_tenant)) + .route( + "/api/v2/tenants/{tenant}/databases", + get(list_databases).post(create_database), + ) + .route( + "/api/v2/tenants/{tenant}/databases/{database}", + get(get_database).delete(delete_database), + ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", - post(create_collection).get(list_collections), + "/api/v2/tenants/{tenant}/databases/{database}/collections", + post(create_collection).get(list_collections), ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", + "/api/v2/tenants/{tenant}/databases/{database}/collections_count", get(count_collections), ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", - get(get_collection).put(update_collection).delete(delete_collection), + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", + get(get_collection) + .put(update_collection) + .delete(delete_collection), ) .route( - "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add", post(collection_add), ) .route( - "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/update", post(collection_update), ) .route( - "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/upsert", post(collection_upsert), ) .route( - "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete", post(collection_delete), ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/count", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/count", get(collection_count), ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/get", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/get", post(collection_get), ) .route( - "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/query", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", post(collection_query), ) .merge(docs_router) .with_state(server) - .layer(DefaultBodyLimit::max(6000000*8)); // TODO: add to server configuration + .layer(DefaultBodyLimit::max(6000000 * 8)); // TODO: add to server configuration let app = add_tracing_middleware(app); // TODO: configuration for this @@ -468,7 +479,7 @@ struct CreateDatabasePayload { /// Creates a new database for a given tenant. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant_id}/databases", + path = "/api/v2/tenants/{tenant}/databases", request_body = CreateDatabasePayload, responses( (status = 200, description = "Database created successfully", body = CreateDatabaseResponse), @@ -476,23 +487,23 @@ struct CreateDatabasePayload { (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID to associate with the new database") + ("tenant" = String, Path, description = "Tenant ID to associate with the new database") ) )] async fn create_database( headers: HeaderMap, - Path(tenant_id): Path, + Path(tenant): Path, State(mut server): State, Json(CreateDatabasePayload { name }): Json, ) -> Result, ServerError> { server.metrics.create_database.add(1, &[]); - tracing::info!("Creating database [{}] for tenant [{}]", name, tenant_id); + tracing::info!("Creating database [{}] for tenant [{}]", name, tenant); server .authenticate_and_authorize( &headers, AuthzAction::CreateDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), + tenant: Some(tenant.clone()), database: Some(name.clone()), collection: None, }, @@ -503,14 +514,12 @@ async fn create_database( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::CreateDatabase, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::CreateDatabase, tenant.clone(), api_token); quota_payload = quota_payload.with_collection_name(&name); server.quota_enforcer.enforce("a_payload).await?; - let _guard = server.scorecard_request(&[ - "op:create_database", - format!("tenant:{}", tenant_id).as_str(), - ]); - let create_database_request = CreateDatabaseRequest::try_new(tenant_id, name)?; + let _guard = + server.scorecard_request(&["op:create_database", format!("tenant:{}", tenant).as_str()]); + let create_database_request = CreateDatabaseRequest::try_new(tenant, name)?; let res = server .frontend .create_database(create_database_request) @@ -528,50 +537,48 @@ struct ListDatabasesParams { /// Lists all databases for a given tenant. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases", + path = "/api/v2/tenants/{tenant}/databases", responses( (status = 200, description = "List of databases", body = ListDatabasesResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID to list databases for"), + ("tenant" = String, Path, description = "Tenant ID to list databases for"), ("limit" = Option, Query, description = "Limit for pagination"), ("offset" = Option, Query, description = "Offset for pagination") ) )] async fn list_databases( headers: HeaderMap, - Path(tenant_id): Path, + Path(tenant): Path, Query(ListDatabasesParams { limit, offset }): Query, State(mut server): State, ) -> Result, ServerError> { server.metrics.list_databases.add(1, &[]); - tracing::info!("Listing database for tenant [{}]", tenant_id); + tracing::info!("Listing database for tenant [{}]", tenant); server .authenticate_and_authorize( &headers, AuthzAction::ListDatabases, AuthzResource { - tenant: Some(tenant_id.clone()), + tenant: Some(tenant.clone()), database: None, collection: None, }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:list_databases", - format!("tenant:{}", tenant_id).as_str(), - ]); + let _guard = + server.scorecard_request(&["op:list_databases", format!("tenant:{}", tenant).as_str()]); - let request = ListDatabasesRequest::try_new(tenant_id, limit, offset)?; + let request = ListDatabasesRequest::try_new(tenant, limit, offset)?; Ok(Json(server.frontend.list_databases(request).await?)) } /// Retrieves a specific database by name. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", + path = "/api/v2/tenants/{tenant}/databases/{database}", responses( (status = 200, description = "Database retrieved successfully", body = GetDatabaseResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), @@ -579,35 +586,31 @@ async fn list_databases( (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Name of the database to retrieve") + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Name of the database to retrieve") ) )] async fn get_database( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.get_database.add(1, &[]); - tracing::info!( - "Getting database [{}] for tenant [{}]", - database_name, - tenant_id - ); + tracing::info!("Getting database [{}] for tenant [{}]", database, tenant); server .authenticate_and_authorize( &headers, AuthzAction::GetDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; let _guard = - server.scorecard_request(&["op:get_database", format!("tenant:{}", tenant_id).as_str()]); - let request = GetDatabaseRequest::try_new(tenant_id, database_name)?; + server.scorecard_request(&["op:get_database", format!("tenant:{}", tenant).as_str()]); + let request = GetDatabaseRequest::try_new(tenant, database)?; let res = server.frontend.get_database(request).await?; Ok(Json(res)) } @@ -615,7 +618,7 @@ async fn get_database( /// Deletes a specific database. #[utoipa::path( delete, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}", + path = "/api/v2/tenants/{tenant}/databases/{database}", responses( (status = 200, description = "Database deleted successfully", body = DeleteDatabaseResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), @@ -623,37 +626,31 @@ async fn get_database( (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Name of the database to delete") + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Name of the database to delete") ) )] async fn delete_database( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.delete_database.add(1, &[]); - tracing::info!( - "Deleting database [{}] for tenant [{}]", - database_name, - tenant_id - ); + tracing::info!("Deleting database [{}] for tenant [{}]", database, tenant); server .authenticate_and_authorize( &headers, AuthzAction::DeleteDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:delete_database", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = DeleteDatabaseRequest::try_new(tenant_id, database_name)?; + let _guard = + server.scorecard_request(&["op:delete_database", format!("tenant:{}", tenant).as_str()]); + let request = DeleteDatabaseRequest::try_new(tenant, database)?; Ok(Json(server.frontend.delete_database(request).await?)) } @@ -667,30 +664,30 @@ struct ListCollectionsParams { /// Lists all collections in the specified database. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections", responses( (status = 200, description = "List of collections", body = ListCollectionsResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name to list collections from"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name to list collections from"), ("limit" = Option, Query, description = "Limit for pagination"), ("offset" = Option, Query, description = "Offset for pagination") ) )] async fn list_collections( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, Query(ListCollectionsParams { limit, offset }): Query, State(mut server): State, ) -> Result, ServerError> { server.metrics.list_collections.add(1, &[]); tracing::info!( "Listing collections in database [{}] for tenant [{}] with limit [{:?}] and offset [{:?}]", - database_name, - tenant_id, + database, + tenant, limit, offset ); @@ -699,8 +696,8 @@ async fn list_collections( &headers, AuthzAction::ListCollections, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) @@ -709,60 +706,55 @@ async fn list_collections( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::ListCollections, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::ListCollections, tenant.clone(), api_token); if let Some(limit) = limit { quota_payload = quota_payload.with_limit(limit); } server.quota_enforcer.enforce("a_payload).await?; - let _guard = server.scorecard_request(&[ - "op:list_collections", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = ListCollectionsRequest::try_new(tenant_id, database_name, limit, offset)?; + let _guard = + server.scorecard_request(&["op:list_collections", format!("tenant:{}", tenant).as_str()]); + let request = ListCollectionsRequest::try_new(tenant, database, limit, offset)?; Ok(Json(server.frontend.list_collections(request).await?)) } /// Retrieves the total number of collections in a given database. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections_count", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections_count", responses( (status = 200, description = "Count of collections", body = CountCollectionsResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name to count collections from") + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name to count collections from") ) )] async fn count_collections( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.count_collections.add(1, &[]); - tracing::info!( - "Counting number of collections in database [{database_name}] for tenant [{tenant_id}]", - ); + tracing::info!("Counting number of collections in database [{database}] for tenant [{tenant}]",); server .authenticate_and_authorize( &headers, AuthzAction::CountCollections, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; let _guard = server.scorecard_request(&[ "op:count_collections", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); - let request = CountCollectionsRequest::try_new(tenant_id, database_name)?; + let request = CountCollectionsRequest::try_new(tenant, database)?; Ok(Json(server.frontend.count_collections(request).await?)) } @@ -777,7 +769,7 @@ pub struct CreateCollectionPayload { /// Creates a new collection under the specified database. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections", request_body = CreateCollectionPayload, responses( (status = 200, description = "Collection created successfully", body = Collection), @@ -785,25 +777,25 @@ pub struct CreateCollectionPayload { (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name containing the new collection") + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name containing the new collection") ) )] async fn create_collection( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.create_collection.add(1, &[]); - tracing::info!("Creating collection in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!("Creating collection in database [{database}] for tenant [{tenant}]"); server .authenticate_and_authorize( &headers, AuthzAction::CreateCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(payload.name.clone()), }, ) @@ -812,8 +804,7 @@ async fn create_collection( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::CreateCollection, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::CreateCollection, tenant.clone(), api_token); quota_payload = quota_payload.with_collection_name(&payload.name); if let Some(metadata) = &payload.metadata { quota_payload = quota_payload.with_create_collection_metadata(metadata); @@ -821,11 +812,11 @@ async fn create_collection( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:create_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let request = CreateCollectionRequest::try_new( - tenant_id, - database_name, + tenant, + database, payload.name, payload.metadata, payload.configuration, @@ -839,7 +830,7 @@ async fn create_collection( /// Retrieves a collection by ID or name. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", responses( (status = 200, description = "Collection found", body = Collection), (status = 404, description = "Collection not found", body = ErrorResponse), @@ -847,34 +838,34 @@ async fn create_collection( (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), ("collection_id" = String, Path, description = "UUID of the collection") ) )] async fn get_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, + Path((tenant, database, collection_name)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.get_collection.add(1, &[]); - tracing::info!("Getting collection [{collection_name}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Getting collection [{collection_name}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::GetCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_name.clone()), }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:get_collection", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = GetCollectionRequest::try_new(tenant_id, database_name, collection_name)?; + let _guard = + server.scorecard_request(&["op:get_collection", format!("tenant:{}", tenant).as_str()]); + let request = GetCollectionRequest::try_new(tenant, database, collection_name)?; let collection = server.frontend.get_collection(request).await?; Ok(Json(collection)) } @@ -888,7 +879,7 @@ pub struct UpdateCollectionPayload { /// Updates an existing collection's name or metadata. #[utoipa::path( put, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", request_body = UpdateCollectionPayload, responses( (status = 200, description = "Collection updated successfully", body = UpdateCollectionResponse), @@ -897,26 +888,28 @@ pub struct UpdateCollectionPayload { (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), ("collection_id" = String, Path, description = "UUID of the collection to update") ) )] async fn update_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.update_collection.add(1, &[]); - tracing::info!("Updating collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Updating collection [{collection_id}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::UpdateCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -925,8 +918,7 @@ async fn update_collection( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::UpdateCollection, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::UpdateCollection, tenant.clone(), api_token); if let Some(new_name) = &payload.new_name { quota_payload = quota_payload.with_collection_new_name(new_name); } @@ -936,7 +928,7 @@ async fn update_collection( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:update_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let collection_id = CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?; @@ -957,7 +949,7 @@ async fn update_collection( /// Deletes a collection in a given database. #[utoipa::path( delete, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", responses( (status = 200, description = "Collection deleted successfully", body = UpdateCollectionResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), @@ -965,35 +957,37 @@ async fn update_collection( (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), ("collection_id" = String, Path, description = "UUID of the collection to delete") ) )] async fn delete_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, + Path((tenant, database, collection_name)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.delete_collection.add(1, &[]); - tracing::info!("Deleting collection [{collection_name}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Deleting collection [{collection_name}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::DeleteCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_name.clone()), }, ) .await?; let _guard = server.scorecard_request(&[ "op:delete_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let request = - chroma_types::DeleteCollectionRequest::try_new(tenant_id, database_name, collection_name)?; + chroma_types::DeleteCollectionRequest::try_new(tenant, database, collection_name)?; server.frontend.delete_collection(request).await?; Ok(Json(UpdateCollectionResponse {})) @@ -1011,7 +1005,7 @@ pub struct AddCollectionRecordsPayload { /// Adds records to a collection. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add", request_body = AddCollectionRecordsPayload, responses( (status = 201, description = "Collection added successfully", body = AddCollectionRecordsResponse), @@ -1020,7 +1014,7 @@ pub struct AddCollectionRecordsPayload { )] async fn collection_add( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -1030,8 +1024,8 @@ async fn collection_add( &headers, AuthzAction::Add, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1042,7 +1036,7 @@ async fn collection_add( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Add, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Add, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_add_embeddings(embeddings); @@ -1060,13 +1054,13 @@ async fn collection_add( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::AddCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -1092,7 +1086,7 @@ pub struct UpdateCollectionRecordsPayload { /// Updates records in a collection by ID. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/update", request_body = UpdateCollectionRecordsPayload, responses( (status = 200, description = "Collection updated successfully", body = UpdateCollectionRecordsResponse), @@ -1101,7 +1095,7 @@ pub struct UpdateCollectionRecordsPayload { )] async fn collection_update( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -1111,8 +1105,8 @@ async fn collection_update( &headers, AuthzAction::Update, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1123,7 +1117,7 @@ async fn collection_update( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Update, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Update, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_update_embeddings(embeddings); @@ -1140,13 +1134,13 @@ async fn collection_update( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::UpdateCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -1170,7 +1164,7 @@ pub struct UpsertCollectionRecordsPayload { /// Upserts records in a collection (create if not exists, otherwise update). #[utoipa::path( post, - path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/upsert", request_body = UpsertCollectionRecordsPayload, responses( (status = 200, description = "Records upserted successfully", body = UpsertCollectionRecordsResponse), @@ -1180,13 +1174,13 @@ pub struct UpsertCollectionRecordsPayload { ), params( ("tenant" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name"), + ("database" = String, Path, description = "Database name"), ("collection_id" = String, Path, description = "Collection ID"), ) )] async fn collection_upsert( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -1196,8 +1190,8 @@ async fn collection_upsert( &headers, AuthzAction::Update, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1208,7 +1202,7 @@ async fn collection_upsert( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Upsert, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Upsert, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_add_embeddings(embeddings); @@ -1226,13 +1220,13 @@ async fn collection_upsert( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::UpsertCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -1254,7 +1248,7 @@ pub struct DeleteCollectionRecordsPayload { /// Deletes records in a collection. Can filter by IDs or metadata. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete", request_body = DeleteCollectionRecordsPayload, responses( (status = 200, description = "Records deleted successfully", body = DeleteCollectionRecordsResponse), @@ -1264,13 +1258,13 @@ pub struct DeleteCollectionRecordsPayload { ), params( ("tenant" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name"), + ("database" = String, Path, description = "Database name"), ("collection_id" = String, Path, description = "Collection ID"), ) )] async fn collection_delete( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -1280,8 +1274,8 @@ async fn collection_delete( &headers, AuthzAction::Delete, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1293,7 +1287,7 @@ async fn collection_delete( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Delete, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Delete, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1303,13 +1297,13 @@ async fn collection_delete( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::DeleteCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, r#where, @@ -1323,7 +1317,7 @@ async fn collection_delete( /// Retrieves the number of records in a collection. #[utoipa::path( get, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/count", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/count", responses( (status = 200, description = "Number of records in the collection", body = CountResponse), (status = 401, description = "Unauthorized", body = ErrorResponse), @@ -1331,47 +1325,47 @@ async fn collection_delete( (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID for the collection"), - ("database_name" = String, Path, description = "Database containing this collection"), + ("tenant" = String, Path, description = "Tenant ID for the collection"), + ("database" = String, Path, description = "Database containing this collection"), ("collection_id" = String, Path, description = "Collection ID whose records are counted") ) )] async fn collection_count( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.collection_count.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), - KeyValue::new("database_name", database_name.clone()), + KeyValue::new("tenant", tenant.clone()), + KeyValue::new("database", database.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); tracing::info!( - "Counting number of records in collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Counting number of records in collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); server .authenticate_and_authorize( &headers, AuthzAction::Count, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) .await?; let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = CountRequest::try_new( - tenant_id, - database_name, + tenant, + database, CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?, )?; @@ -1392,7 +1386,7 @@ pub struct GetRequestPayload { /// Retrieves records from a collection by ID or metadata filter. #[utoipa::path( post, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/get", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/get", request_body = GetRequestPayload, responses( (status = 200, description = "Records retrieved from the collection", body = GetResponse), @@ -1401,21 +1395,21 @@ pub struct GetRequestPayload { (status = 500, description = "Server error", body = ErrorResponse) ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name for the collection"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name for the collection"), ("collection_id" = String, Path, description = "Collection ID to fetch records from") ) )] async fn collection_get( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.collection_get.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), + KeyValue::new("tenant", tenant.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); @@ -1424,8 +1418,8 @@ async fn collection_get( &headers, AuthzAction::Get, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1437,7 +1431,7 @@ async fn collection_get( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Get, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Get, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1449,16 +1443,16 @@ async fn collection_get( } server.quota_enforcer.enforce("a_payload).await?; tracing::info!( - "Getting records from collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Getting records from collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = GetRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, parsed_where, @@ -1484,7 +1478,7 @@ pub struct QueryRequestPayload { /// Query a collection in a variety of ways, including vector search, metadata filtering, and full-text search #[utoipa::path( post, - path = "/api/v2/tenants/{tenant_id}/databases/{database_name}/collections/{collection_id}/query", + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", request_body = QueryRequestPayload, responses( (status = 200, description = "Records matching the query", body = QueryResponse), @@ -1493,8 +1487,8 @@ pub struct QueryRequestPayload { (status = 500, description = "Server error", body = ErrorResponse), ), params( - ("tenant_id" = String, Path, description = "Tenant ID"), - ("database_name" = String, Path, description = "Database name containing the collection"), + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name containing the collection"), ("collection_id" = String, Path, description = "Collection ID to query"), ("limit" = Option, Query, description = "Limit for pagination"), ("offset" = Option, Query, description = "Offset for pagination") @@ -1502,14 +1496,14 @@ pub struct QueryRequestPayload { )] async fn collection_query( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.collection_query.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), + KeyValue::new("tenant", tenant.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); @@ -1518,8 +1512,8 @@ async fn collection_query( &headers, AuthzAction::Query, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1531,7 +1525,7 @@ async fn collection_query( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Query, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Query, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1544,18 +1538,18 @@ async fn collection_query( } server.quota_enforcer.enforce("a_payload).await?; tracing::info!( - "Querying records from collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Querying records from collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = QueryRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, parsed_where, From ebf3cfbcf385ead78ab6731be65a7dfdfa6a42ec Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 10:49:39 -0800 Subject: [PATCH 26/30] openapi: update Heartbeat description --- rust/frontend/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index a561c6a06d8..dc9b0392a35 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -301,7 +301,7 @@ async fn healthcheck(State(server): State) -> impl IntoResponse (code, Json(res)) } -/// Heartbeat endpoint that returns a nanosecond timestamp of the current time. Useful for making sure the client remains connected. +/// Heartbeat endpoint that returns a nanosecond timestamp of the current time. #[utoipa::path( get, path = "/api/v2/heartbeat", From 6da8d2f0487ee80afcbb62a567ef269448d92307 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 10:52:48 -0800 Subject: [PATCH 27/30] frontend: document + handle unwrap of components --- rust/frontend/src/server.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index dc9b0392a35..8a9848e0581 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1577,7 +1577,11 @@ async fn v1_deprecation_notice() -> Response { struct ChromaTokenSecurityAddon; impl Modify for ChromaTokenSecurityAddon { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { - let components = openapi.components.as_mut().unwrap(); + // NOTE(philipithomas) - This unwrap is usually safe, and will crash the service on initialization if it's not. + let components = openapi + .components + .as_mut() + .expect("It should be able to get components as mutable"); components.add_security_scheme( "x-chroma-token", SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("x-chroma-token"))), From 8cf8daf756317b951c3217ff37029f11fe7f1b8f Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 13:06:17 -0800 Subject: [PATCH 28/30] frontend: simplify heartbeat error --- rust/python_bindings/src/bindings.rs | 2 +- rust/types/src/api_types.rs | 20 +++----------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index a7895f1f6d1..55845b8439e 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -131,7 +131,7 @@ impl Bindings { fn heartbeat(&self) -> ChromaPyResult { let duration_since_epoch = std::time::SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) - .map_err(|err| HeartbeatError::CouldNotGetTime(err.into()))?; + .map_err(HeartbeatError::from)?; Ok(duration_since_epoch.as_nanos()) } diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index de1c108e3d5..32e9eb72300 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -125,29 +125,15 @@ pub struct HeartbeatResponse { pub nanosecond_heartbeat: u128, } -#[derive(Debug, Error, Serialize, ToSchema)] -#[error("system time error: {message}")] -pub struct ChromaSystemTimeError { - message: String, -} - -impl From for ChromaSystemTimeError { - fn from(err: SystemTimeError) -> Self { - Self { - message: err.to_string(), - } - } -} - #[derive(Debug, Error, ToSchema)] pub enum HeartbeatError { - #[error(transparent)] - CouldNotGetTime(#[from] ChromaSystemTimeError), + #[error("system time error: {0}")] + CouldNotGetTime(String), } impl From for HeartbeatError { fn from(err: SystemTimeError) -> Self { - HeartbeatError::CouldNotGetTime(ChromaSystemTimeError::from(err)) + HeartbeatError::CouldNotGetTime(err.to_string()) } } From 1aa91c9a1f58c5131a517acf1eefe0f789497100 Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 16:03:59 -0800 Subject: [PATCH 29/30] frontend: fix get_tenant endpoint to use name --- rust/frontend/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 8a9848e0581..4480eba1113 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -180,7 +180,7 @@ impl FrontendServer { .route("/api/v2/version", get(version)) .route("/api/v2/auth/identity", get(get_user_identity)) .route("/api/v2/tenants", post(create_tenant)) - .route("/api/v2/tenants/{tenant}", get(get_tenant)) + .route("/api/v2/tenants/{tenant_name}", get(get_tenant)) .route( "/api/v2/tenants/{tenant}/databases", get(list_databases).post(create_database), From f501bf55f653019138dd4f82dedfbe5cd62d4dbc Mon Sep 17 00:00:00 2001 From: "Philip I. Thomas" Date: Thu, 20 Feb 2025 16:52:40 -0800 Subject: [PATCH 30/30] errors: fix spelling mistake --- rust/frontend/src/types/errors.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/frontend/src/types/errors.rs b/rust/frontend/src/types/errors.rs index c69734e5b84..3be52830e6e 100644 --- a/rust/frontend/src/types/errors.rs +++ b/rust/frontend/src/types/errors.rs @@ -20,7 +20,7 @@ pub enum ValidationError { DimensionMismatch(u32, u32), #[error("Error getting collection: {0}")] GetCollection(#[from] GetCollectionError), - #[error("Error updatding collection: {0}")] + #[error("Error updating collection: {0}")] UpdateCollection(#[from] UpdateCollectionError), }