diff --git a/Cargo.lock b/Cargo.lock index 99f5a5c6386..41e629ca563 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -128,6 +128,15 @@ version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +[[package]] +name = "arbitrary" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arc-swap" version = "1.7.1" @@ -867,9 +876,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "itoa", + "matchit 0.7.3", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 1.0.2", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", "axum-macros", "bytes", + "form_urlencoded", "futures-util", "http 1.1.0", "http-body 1.0.1", @@ -877,7 +913,7 @@ dependencies = [ "hyper 1.5.1", "hyper-util", "itoa", - "matchit", + "matchit 0.8.4", "memchr", "mime", "percent-encoding", @@ -913,14 +949,33 @@ dependencies = [ "sync_wrapper 1.0.2", "tower-layer", "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", "tracing", ] [[package]] name = "axum-macros" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" dependencies = [ "proc-macro2", "quote", @@ -1103,9 +1158,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" @@ -1328,7 +1383,7 @@ name = "chroma-frontend" version = "0.7.0" dependencies = [ "async-trait", - "axum", + "axum 0.8.1", "backon", "chroma-cache", "chroma-config", @@ -1358,6 +1413,9 @@ dependencies = [ "tower 0.4.13", "tower-http 0.6.2", "tracing", + "utoipa", + "utoipa-axum", + "utoipa-swagger-ui", "uuid", "validator", ] @@ -1402,7 +1460,7 @@ name = "chroma-load" version = "0.1.0" dependencies = [ "async-trait", - "axum", + "axum 0.8.1", "chromadb", "chrono", "clap", @@ -1645,6 +1703,7 @@ dependencies = [ "tokio", "tonic", "tonic-build", + "utoipa", "uuid", "validator", ] @@ -2184,6 +2243,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "derive_utils" version = "0.14.2" @@ -3486,6 +3556,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.2", + "serde", ] [[package]] @@ -3926,6 +3997,12 @@ dependencies = [ "serde", ] +[[package]] +name = "lockfree-object-pool" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" + [[package]] name = "log" version = "0.4.22" @@ -4050,6 +4127,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -4125,6 +4208,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -6219,6 +6312,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "siphasher" version = "1.0.1" @@ -7079,7 +7178,7 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum", + "axum 0.7.9", "base64 0.22.1", "bytes", "h2 0.4.7", @@ -7350,6 +7449,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -7448,6 +7553,62 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utoipa" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435c6f69ef38c9017b4b4eea965dfb91e71e53d869e896db40d1cf2441dd75c0" +dependencies = [ + "indexmap 2.6.0", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-axum" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c25bae5bccc842449ec0c5ddc5cbb6a3a1eaeac4503895dc105a1138f8234a0" +dependencies = [ + "axum 0.8.1", + "paste", + "tower-layer", + "tower-service", + "utoipa", +] + +[[package]] +name = "utoipa-gen" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a77d306bc75294fd52f3e99b13ece67c02c1a2789190a6f31d32f736624326f7" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.89", + "uuid", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "161166ec520c50144922a625d8bc4925cc801b2dda958ab69878527c0e5c5d61" +dependencies = [ + "axum 0.8.1", + "base64 0.22.1", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "url", + "utoipa", + "zip", +] + [[package]] name = "uuid" version = "1.11.0" @@ -8239,6 +8400,37 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zip" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae9c1ea7b3a5e1f4b922ff856a129881167511563dc219869afe3787fc0c1a45" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap 2.6.0", + "memchr", + "thiserror 2.0.4", + "zopfli", +] + +[[package]] +name = "zopfli" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946" +dependencies = [ + "bumpalo", + "crc32fast", + "lockfree-object-pool", + "log", + "once_cell", + "simd-adler32", +] + [[package]] name = "zstd" version = "0.13.0" diff --git a/Cargo.toml b/Cargo.toml index 6fc4a015976..e2841abc021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = ["rust/benchmark", "rust/blockstore", "rust/cache", "rust/chroma", "ru [workspace.dependencies] arrow = "52.2.0" async-trait = "0.1" -axum = { version = "0.7", features = ["macros"] } +axum = { version = "0.8", features = ["macros"] } chrono = { version = "0.4", features = ["serde"] } clap = { version = "4", features = ["derive"] } criterion = { version = "0.5", features = ["async_tokio"] } @@ -38,6 +38,7 @@ tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } +utoipa = { version = "5.0.0", features = ["macros", "axum_extras", "debug", "uuid"] } sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite"] } sha2 = "0.10.8" md5 = "0.7.0" diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index 7a72ac31c19..ef96e66c99d 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -42,3 +42,6 @@ chroma-system = { workspace = true } chroma-tracing = { workspace = true } chroma-types = { workspace = true } chroma-sqlite = { workspace = true } +utoipa = { workspace = true } +utoipa-axum = { version = "0.2.0", features = ["debug"] } +utoipa-swagger-ui = { version = "9", features = ["axum"] } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index aa9f835cd37..4480eba1113 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -29,6 +29,11 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}; +use utoipa::ToSchema; +use utoipa::{Modify, OpenApi}; +use utoipa_axum::router::OpenApiRouter; +use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; use crate::{ @@ -151,9 +156,23 @@ impl FrontendServer { #[allow(dead_code)] pub async fn run(server: FrontendServer) { let circuit_breaker_config = server.config.circuit_breaker.clone(); + + let (docs_router, docs_api) = + OpenApiRouter::with_openapi(ApiDoc::openapi()).split_for_parts(); + + let docs_router = docs_router.merge(SwaggerUi::new("/docs").url("/openapi.json", docs_api)); + let app = Router::new() // `GET /` goes to `root` - .route("/api/v1/*any", get(v1_deprecation_notice).put(v1_deprecation_notice).patch(v1_deprecation_notice).delete(v1_deprecation_notice).head(v1_deprecation_notice).options(v1_deprecation_notice)) + .route( + "/api/v1/{*any}", + get(v1_deprecation_notice) + .put(v1_deprecation_notice) + .patch(v1_deprecation_notice) + .delete(v1_deprecation_notice) + .head(v1_deprecation_notice) + .options(v1_deprecation_notice), + ) .route("/api/v2/healthcheck", get(healthcheck)) .route("/api/v2/heartbeat", get(heartbeat)) .route("/api/v2/pre-flight-checks", get(pre_flight_checks)) @@ -161,51 +180,60 @@ impl FrontendServer { .route("/api/v2/version", get(version)) .route("/api/v2/auth/identity", get(get_user_identity)) .route("/api/v2/tenants", post(create_tenant)) - .route("/api/v2/tenants/:tenant_name", get(get_tenant)) - .route("/api/v2/tenants/:tenant_id/databases", get(list_databases).post(create_database)) - .route("/api/v2/tenants/:tenant_id/databases/:name", get(get_database).delete(delete_database)) + .route("/api/v2/tenants/{tenant_name}", get(get_tenant)) + .route( + "/api/v2/tenants/{tenant}/databases", + get(list_databases).post(create_database), + ) + .route( + "/api/v2/tenants/{tenant}/databases/{database}", + get(get_database).delete(delete_database), + ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections", - post(create_collection).get(list_collections), + "/api/v2/tenants/{tenant}/databases/{database}/collections", + post(create_collection).get(list_collections), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections_count", + "/api/v2/tenants/{tenant}/databases/{database}/collections_count", get(count_collections), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id", - get(get_collection).put(update_collection).delete(delete_collection), + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", + get(get_collection) + .put(update_collection) + .delete(delete_collection), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/add", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add", post(collection_add), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/update", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/update", post(collection_update), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/upsert", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/upsert", post(collection_upsert), ) .route( - "/api/v2/tenants/:tenant/databases/:database_name/collections/:collection_id/delete", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete", post(collection_delete), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/count", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/count", get(collection_count), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/get", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/get", post(collection_get), ) .route( - "/api/v2/tenants/:tenant_id/databases/:database_name/collections/:collection_id/query", + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", post(collection_query), ) + .merge(docs_router) .with_state(server) - .layer(DefaultBodyLimit::max(6000000*8)); // TODO: add to server configuration + .layer(DefaultBodyLimit::max(6000000 * 8)); // TODO: add to server configuration let app = add_tracing_middleware(app); // TODO: configuration for this @@ -254,6 +282,15 @@ impl FrontendServer { // These handlers simply proxy the call and the relevant inputs into // the appropriate method on the `FrontendServer` struct. +/// Health check endpoint that returns 200 if the server and executor are ready +#[utoipa::path( + get, + path = "/api/v2/healthcheck", + responses( + (status = 200, description = "Success", body = String, content_type = "application/json"), + (status = 503, description = "Service Unavailable", body = ErrorResponse), + ) +)] async fn healthcheck(State(server): State) -> impl IntoResponse { server.metrics.healthcheck.add(1, &[]); let res = server.frontend.healthcheck().await; @@ -261,10 +298,18 @@ async fn healthcheck(State(server): State) -> impl IntoResponse tonic::Code::Ok => StatusCode::OK, _ => StatusCode::SERVICE_UNAVAILABLE, }; - (code, Json(res)) } +/// Heartbeat endpoint that returns a nanosecond timestamp of the current time. +#[utoipa::path( + get, + path = "/api/v2/heartbeat", + responses( + (status = 200, description = "Success", body = HeartbeatResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn heartbeat( State(server): State, ) -> Result, ServerError> { @@ -272,7 +317,15 @@ async fn heartbeat( Ok(Json(server.frontend.heartbeat().await?)) } -// Dummy implementation for now +/// Pre-flight checks endpoint reporting basic readiness info. +#[utoipa::path( + get, + path = "/api/v2/pre-flight-checks", + responses( + (status = 200, description = "Pre flight checks", body = ChecklistResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn pre_flight_checks( State(server): State, ) -> Result, ServerError> { @@ -282,6 +335,16 @@ async fn pre_flight_checks( })) } +/// Reset endpoint allowing authorized users to reset the database. +#[utoipa::path( + post, + path = "/api/v2/reset", + responses( + (status = 200, description = "Reset successful", body = bool), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn reset( headers: HeaderMap, State(mut server): State, @@ -302,6 +365,14 @@ async fn reset( Ok(Json(true)) } +/// Returns the version of the server. +#[utoipa::path( + get, + path = "/api/v2/version", + responses( + (status = 200, description = "Get server version", body = String) + ) +)] async fn version(State(server): State) -> Json { server.metrics.version.add(1, &[]); // TODO: Decide on how to handle versioning across python / rust frontend @@ -309,6 +380,15 @@ async fn version(State(server): State) -> Json { Json("0.7.0".to_string()) } +/// Retrieves the current user's identity, tenant, and databases. +#[utoipa::path( + get, + path = "/api/v2/auth/identity", + responses( + (status = 200, description = "Get user identity", body = GetUserIdentityResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn get_user_identity( headers: HeaderMap, State(server): State, @@ -317,11 +397,22 @@ async fn get_user_identity( Ok(Json(server.auth.get_user_identity(&headers).await?)) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, ToSchema)] struct CreateTenantPayload { name: String, } +/// Creates a new tenant. +#[utoipa::path( + post, + path = "/api/v2/tenants", + request_body = CreateTenantPayload, + responses( + (status = 200, description = "Tenant created successfully", body = CreateTenantResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn create_tenant( headers: HeaderMap, State(mut server): State, @@ -344,6 +435,20 @@ async fn create_tenant( Ok(Json(server.frontend.create_tenant(request).await?)) } +/// Returns an existing tenant by name. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant_name}", + params( + ("tenant_name" = String, Path, description = "Tenant name or ID to retrieve") + ), + responses( + (status = 200, description = "Tenant found", body = GetTenantResponse), + (status = 404, description = "Tenant not found", body = ErrorResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ) +)] async fn get_tenant( headers: HeaderMap, Path(name): Path, @@ -366,25 +471,39 @@ async fn get_tenant( Ok(Json(server.frontend.get_tenant(request).await?)) } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, ToSchema, Debug)] struct CreateDatabasePayload { name: String, } +/// Creates a new database for a given tenant. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases", + request_body = CreateDatabasePayload, + responses( + (status = 200, description = "Database created successfully", body = CreateDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID to associate with the new database") + ) +)] async fn create_database( headers: HeaderMap, - Path(tenant_id): Path, + Path(tenant): Path, State(mut server): State, Json(CreateDatabasePayload { name }): Json, ) -> Result, ServerError> { server.metrics.create_database.add(1, &[]); - tracing::info!("Creating database [{}] for tenant [{}]", name, tenant_id); + tracing::info!("Creating database [{}] for tenant [{}]", name, tenant); server .authenticate_and_authorize( &headers, AuthzAction::CreateDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), + tenant: Some(tenant.clone()), database: Some(name.clone()), collection: None, }, @@ -395,14 +514,12 @@ async fn create_database( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::CreateDatabase, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::CreateDatabase, tenant.clone(), api_token); quota_payload = quota_payload.with_collection_name(&name); server.quota_enforcer.enforce("a_payload).await?; - let _guard = server.scorecard_request(&[ - "op:create_database", - format!("tenant:{}", tenant_id).as_str(), - ]); - let create_database_request = CreateDatabaseRequest::try_new(tenant_id, name)?; + let _guard = + server.scorecard_request(&["op:create_database", format!("tenant:{}", tenant).as_str()]); + let create_database_request = CreateDatabaseRequest::try_new(tenant, name)?; let res = server .frontend .create_database(create_database_request) @@ -410,118 +527,167 @@ async fn create_database( Ok(Json(res)) } -#[derive(Deserialize)] +#[derive(Deserialize, ToSchema, Debug)] struct ListDatabasesParams { limit: Option, #[serde(default)] offset: u32, } +/// Lists all databases for a given tenant. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases", + responses( + (status = 200, description = "List of databases", body = ListDatabasesResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID to list databases for"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") + ) +)] async fn list_databases( headers: HeaderMap, - Path(tenant_id): Path, + Path(tenant): Path, Query(ListDatabasesParams { limit, offset }): Query, State(mut server): State, ) -> Result, ServerError> { server.metrics.list_databases.add(1, &[]); - tracing::info!("Listing database for tenant [{}]", tenant_id); + tracing::info!("Listing database for tenant [{}]", tenant); server .authenticate_and_authorize( &headers, AuthzAction::ListDatabases, AuthzResource { - tenant: Some(tenant_id.clone()), + tenant: Some(tenant.clone()), database: None, collection: None, }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:list_databases", - format!("tenant:{}", tenant_id).as_str(), - ]); + let _guard = + server.scorecard_request(&["op:list_databases", format!("tenant:{}", tenant).as_str()]); - let request = ListDatabasesRequest::try_new(tenant_id, limit, offset)?; + let request = ListDatabasesRequest::try_new(tenant, limit, offset)?; Ok(Json(server.frontend.list_databases(request).await?)) } +/// Retrieves a specific database by name. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}", + responses( + (status = 200, description = "Database retrieved successfully", body = GetDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Database not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Name of the database to retrieve") + ) +)] async fn get_database( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.get_database.add(1, &[]); - tracing::info!( - "Getting database [{}] for tenant [{}]", - database_name, - tenant_id - ); + tracing::info!("Getting database [{}] for tenant [{}]", database, tenant); server .authenticate_and_authorize( &headers, AuthzAction::GetDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; let _guard = - server.scorecard_request(&["op:get_database", format!("tenant:{}", tenant_id).as_str()]); - let request = GetDatabaseRequest::try_new(tenant_id, database_name)?; + server.scorecard_request(&["op:get_database", format!("tenant:{}", tenant).as_str()]); + let request = GetDatabaseRequest::try_new(tenant, database)?; let res = server.frontend.get_database(request).await?; Ok(Json(res)) } +/// Deletes a specific database. +#[utoipa::path( + delete, + path = "/api/v2/tenants/{tenant}/databases/{database}", + responses( + (status = 200, description = "Database deleted successfully", body = DeleteDatabaseResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Database not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Name of the database to delete") + ) +)] async fn delete_database( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.delete_database.add(1, &[]); - tracing::info!( - "Deleting database [{}] for tenant [{}]", - database_name, - tenant_id - ); + tracing::info!("Deleting database [{}] for tenant [{}]", database, tenant); server .authenticate_and_authorize( &headers, AuthzAction::DeleteDatabase, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:delete_database", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = DeleteDatabaseRequest::try_new(tenant_id, database_name)?; + let _guard = + server.scorecard_request(&["op:delete_database", format!("tenant:{}", tenant).as_str()]); + let request = DeleteDatabaseRequest::try_new(tenant, database)?; Ok(Json(server.frontend.delete_database(request).await?)) } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] struct ListCollectionsParams { limit: Option, #[serde(default)] offset: u32, } +/// Lists all collections in the specified database. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections", + responses( + (status = 200, description = "List of collections", body = ListCollectionsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name to list collections from"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") + ) +)] async fn list_collections( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, Query(ListCollectionsParams { limit, offset }): Query, State(mut server): State, ) -> Result, ServerError> { server.metrics.list_collections.add(1, &[]); tracing::info!( "Listing collections in database [{}] for tenant [{}] with limit [{:?}] and offset [{:?}]", - database_name, - tenant_id, + database, + tenant, limit, offset ); @@ -530,8 +696,8 @@ async fn list_collections( &headers, AuthzAction::ListCollections, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) @@ -540,51 +706,59 @@ async fn list_collections( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::ListCollections, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::ListCollections, tenant.clone(), api_token); if let Some(limit) = limit { quota_payload = quota_payload.with_limit(limit); } server.quota_enforcer.enforce("a_payload).await?; - let _guard = server.scorecard_request(&[ - "op:list_collections", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = ListCollectionsRequest::try_new(tenant_id, database_name, limit, offset)?; + let _guard = + server.scorecard_request(&["op:list_collections", format!("tenant:{}", tenant).as_str()]); + let request = ListCollectionsRequest::try_new(tenant, database, limit, offset)?; Ok(Json(server.frontend.list_collections(request).await?)) } +/// Retrieves the total number of collections in a given database. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections_count", + responses( + (status = 200, description = "Count of collections", body = CountCollectionsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name to count collections from") + ) +)] async fn count_collections( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.count_collections.add(1, &[]); - tracing::info!( - "Counting collections in database [{}] for tenant [{}]", - database_name, - tenant_id - ); + tracing::info!("Counting number of collections in database [{database}] for tenant [{tenant}]",); server .authenticate_and_authorize( &headers, AuthzAction::CountCollections, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: None, }, ) .await?; let _guard = server.scorecard_request(&[ "op:count_collections", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); - let request = CountCollectionsRequest::try_new(tenant_id, database_name)?; + + let request = CountCollectionsRequest::try_new(tenant, database)?; Ok(Json(server.frontend.count_collections(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)] pub struct CreateCollectionPayload { pub name: String, pub configuration: Option, @@ -592,21 +766,36 @@ pub struct CreateCollectionPayload { pub get_or_create: bool, } +/// Creates a new collection under the specified database. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections", + request_body = CreateCollectionPayload, + responses( + (status = 200, description = "Collection created successfully", body = Collection), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name containing the new collection") + ) +)] async fn create_collection( headers: HeaderMap, - Path((tenant_id, database_name)): Path<(String, String)>, + Path((tenant, database)): Path<(String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.create_collection.add(1, &[]); - tracing::info!("Creating collection in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!("Creating collection in database [{database}] for tenant [{tenant}]"); server .authenticate_and_authorize( &headers, AuthzAction::CreateCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(payload.name.clone()), }, ) @@ -615,8 +804,7 @@ async fn create_collection( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::CreateCollection, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::CreateCollection, tenant.clone(), api_token); quota_payload = quota_payload.with_collection_name(&payload.name); if let Some(metadata) = &payload.metadata { quota_payload = quota_payload.with_create_collection_metadata(metadata); @@ -624,11 +812,11 @@ async fn create_collection( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:create_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let request = CreateCollectionRequest::try_new( - tenant_id, - database_name, + tenant, + database, payload.name, payload.metadata, payload.configuration, @@ -639,54 +827,89 @@ async fn create_collection( Ok(Json(collection)) } +/// Retrieves a collection by ID or name. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", + responses( + (status = 200, description = "Collection found", body = Collection), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection") + ) +)] async fn get_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, + Path((tenant, database, collection_name)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.get_collection.add(1, &[]); - tracing::info!("Getting collection [{collection_name}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Getting collection [{collection_name}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::GetCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_name.clone()), }, ) .await?; - let _guard = server.scorecard_request(&[ - "op:get_collection", - format!("tenant:{}", tenant_id).as_str(), - ]); - let request = GetCollectionRequest::try_new(tenant_id, database_name, collection_name)?; + let _guard = + server.scorecard_request(&["op:get_collection", format!("tenant:{}", tenant).as_str()]); + let request = GetCollectionRequest::try_new(tenant, database, collection_name)?; let collection = server.frontend.get_collection(request).await?; Ok(Json(collection)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)] pub struct UpdateCollectionPayload { pub new_name: Option, pub new_metadata: Option, } +/// Updates an existing collection's name or metadata. +#[utoipa::path( + put, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", + request_body = UpdateCollectionPayload, + responses( + (status = 200, description = "Collection updated successfully", body = UpdateCollectionResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection to update") + ) +)] async fn update_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.update_collection.add(1, &[]); - tracing::info!("Updating collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Updating collection [{collection_id}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::UpdateCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -695,8 +918,7 @@ async fn update_collection( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = - QuotaPayload::new(Action::UpdateCollection, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::UpdateCollection, tenant.clone(), api_token); if let Some(new_name) = &payload.new_name { quota_payload = quota_payload.with_collection_new_name(new_name); } @@ -706,7 +928,7 @@ async fn update_collection( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:update_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let collection_id = CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?; @@ -724,36 +946,54 @@ async fn update_collection( Ok(Json(UpdateCollectionResponse {})) } +/// Deletes a collection in a given database. +#[utoipa::path( + delete, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}", + responses( + (status = 200, description = "Collection deleted successfully", body = UpdateCollectionResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "UUID of the collection to delete") + ) +)] async fn delete_collection( headers: HeaderMap, - Path((tenant_id, database_name, collection_name)): Path<(String, String, String)>, + Path((tenant, database, collection_name)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.delete_collection.add(1, &[]); - tracing::info!("Deleting collection [{collection_name}] in database [{database_name}] for tenant [{tenant_id}]"); + tracing::info!( + "Deleting collection [{collection_name}] in database [{database}] for tenant [{tenant}]" + ); server .authenticate_and_authorize( &headers, AuthzAction::DeleteCollection, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_name.clone()), }, ) .await?; let _guard = server.scorecard_request(&[ "op:delete_collection", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), ]); let request = - chroma_types::DeleteCollectionRequest::try_new(tenant_id, database_name, collection_name)?; + chroma_types::DeleteCollectionRequest::try_new(tenant, database, collection_name)?; server.frontend.delete_collection(request).await?; Ok(Json(UpdateCollectionResponse {})) } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, ToSchema, Debug, Clone)] pub struct AddCollectionRecordsPayload { ids: Vec, embeddings: Option>>, @@ -762,9 +1002,19 @@ pub struct AddCollectionRecordsPayload { metadatas: Option>>, } +/// Adds records to a collection. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/add", + request_body = AddCollectionRecordsPayload, + responses( + (status = 201, description = "Collection added successfully", body = AddCollectionRecordsResponse), + (status = 400, description = "Invalid data for collection addition") + ) +)] async fn collection_add( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -774,8 +1024,8 @@ async fn collection_add( &headers, AuthzAction::Add, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -786,7 +1036,7 @@ async fn collection_add( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Add, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Add, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_add_embeddings(embeddings); @@ -804,13 +1054,13 @@ async fn collection_add( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::AddCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -824,7 +1074,7 @@ async fn collection_add( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct UpdateCollectionRecordsPayload { ids: Vec, embeddings: Option>>>, @@ -833,9 +1083,19 @@ pub struct UpdateCollectionRecordsPayload { metadatas: Option>>, } +/// Updates records in a collection by ID. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/update", + request_body = UpdateCollectionRecordsPayload, + responses( + (status = 200, description = "Collection updated successfully", body = UpdateCollectionRecordsResponse), + (status = 404, description = "Collection not found") + ) +)] async fn collection_update( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -845,8 +1105,8 @@ async fn collection_update( &headers, AuthzAction::Update, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -857,7 +1117,7 @@ async fn collection_update( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Update, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Update, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_update_embeddings(embeddings); @@ -874,13 +1134,13 @@ async fn collection_update( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::UpdateCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -892,7 +1152,7 @@ async fn collection_update( Ok(Json(server.frontend.update(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct UpsertCollectionRecordsPayload { ids: Vec, embeddings: Option>>, @@ -901,9 +1161,26 @@ pub struct UpsertCollectionRecordsPayload { metadatas: Option>>, } +/// Upserts records in a collection (create if not exists, otherwise update). +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/upsert", + request_body = UpsertCollectionRecordsPayload, + responses( + (status = 200, description = "Records upserted successfully", body = UpsertCollectionRecordsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID"), + ) +)] async fn collection_upsert( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -913,8 +1190,8 @@ async fn collection_upsert( &headers, AuthzAction::Update, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -925,7 +1202,7 @@ async fn collection_upsert( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Upsert, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Upsert, tenant.clone(), api_token); quota_payload = quota_payload.with_ids(&payload.ids); if let Some(embeddings) = &payload.embeddings { quota_payload = quota_payload.with_add_embeddings(embeddings); @@ -943,13 +1220,13 @@ async fn collection_upsert( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::UpsertCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, payload.embeddings, @@ -961,16 +1238,33 @@ async fn collection_upsert( Ok(Json(server.frontend.upsert(request).await?)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, ToSchema, Serialize)] pub struct DeleteCollectionRecordsPayload { ids: Option>, #[serde(flatten)] where_fields: RawWhereFields, } +/// Deletes records in a collection. Can filter by IDs or metadata. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete", + request_body = DeleteCollectionRecordsPayload, + responses( + (status = 200, description = "Records deleted successfully", body = DeleteCollectionRecordsResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID"), + ) +)] async fn collection_delete( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { @@ -980,8 +1274,8 @@ async fn collection_delete( &headers, AuthzAction::Delete, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -993,7 +1287,7 @@ async fn collection_delete( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Delete, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Delete, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1003,13 +1297,13 @@ async fn collection_delete( server.quota_enforcer.enforce("a_payload).await?; let _guard = server.scorecard_request(&[ "op:write", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = chroma_types::DeleteCollectionRecordsRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, r#where, @@ -1020,49 +1314,65 @@ async fn collection_delete( Ok(Json(DeleteCollectionRecordsResponse {})) } +/// Retrieves the number of records in a collection. +#[utoipa::path( + get, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/count", + responses( + (status = 200, description = "Number of records in the collection", body = CountResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID for the collection"), + ("database" = String, Path, description = "Database containing this collection"), + ("collection_id" = String, Path, description = "Collection ID whose records are counted") + ) +)] async fn collection_count( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, ) -> Result, ServerError> { server.metrics.collection_count.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), - KeyValue::new("database_name", database_name.clone()), + KeyValue::new("tenant", tenant.clone()), + KeyValue::new("database", database.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); tracing::info!( - "Counting number of records in collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Counting number of records in collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); server .authenticate_and_authorize( &headers, AuthzAction::Count, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) .await?; let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = CountRequest::try_new( - tenant_id, - database_name, + tenant, + database, CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?, )?; Ok(Json(server.frontend.count(request).await?)) } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, ToSchema)] pub struct GetRequestPayload { ids: Option>, #[serde(flatten)] @@ -1073,16 +1383,33 @@ pub struct GetRequestPayload { include: IncludeList, } +/// Retrieves records from a collection by ID or metadata filter. +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/get", + request_body = GetRequestPayload, + responses( + (status = 200, description = "Records retrieved from the collection", body = GetResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name for the collection"), + ("collection_id" = String, Path, description = "Collection ID to fetch records from") + ) +)] async fn collection_get( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.collection_get.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), + KeyValue::new("tenant", tenant.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); @@ -1091,8 +1418,8 @@ async fn collection_get( &headers, AuthzAction::Get, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1104,7 +1431,7 @@ async fn collection_get( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Get, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Get, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1116,16 +1443,16 @@ async fn collection_get( } server.quota_enforcer.enforce("a_payload).await?; tracing::info!( - "Getting records from collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Getting records from collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = GetRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, parsed_where, @@ -1137,7 +1464,7 @@ async fn collection_get( Ok(Json(res)) } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Serialize, ToSchema)] pub struct QueryRequestPayload { ids: Option>, #[serde(flatten)] @@ -1148,16 +1475,35 @@ pub struct QueryRequestPayload { include: IncludeList, } +/// Query a collection in a variety of ways, including vector search, metadata filtering, and full-text search +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", + request_body = QueryRequestPayload, + responses( + (status = 200, description = "Records matching the query", body = QueryResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 404, description = "Collection not found", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse), + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name containing the collection"), + ("collection_id" = String, Path, description = "Collection ID to query"), + ("limit" = Option, Query, description = "Limit for pagination"), + ("offset" = Option, Query, description = "Offset for pagination") + ) +)] async fn collection_query( headers: HeaderMap, - Path((tenant_id, database_name, collection_id)): Path<(String, String, String)>, + Path((tenant, database, collection_id)): Path<(String, String, String)>, State(mut server): State, Json(payload): Json, ) -> Result, ServerError> { server.metrics.collection_query.add( 1, &[ - KeyValue::new("tenant_id", tenant_id.clone()), + KeyValue::new("tenant", tenant.clone()), KeyValue::new("collection_id", collection_id.clone()), ], ); @@ -1166,8 +1512,8 @@ async fn collection_query( &headers, AuthzAction::Query, AuthzResource { - tenant: Some(tenant_id.clone()), - database: Some(database_name.clone()), + tenant: Some(tenant.clone()), + database: Some(database.clone()), collection: Some(collection_id.clone()), }, ) @@ -1179,7 +1525,7 @@ async fn collection_query( .get("x-chroma-token") .map(|val| val.to_str().unwrap_or_default()) .map(|val| val.to_string()); - let mut quota_payload = QuotaPayload::new(Action::Query, tenant_id.clone(), api_token); + let mut quota_payload = QuotaPayload::new(Action::Query, tenant.clone(), api_token); if let Some(ids) = &payload.ids { quota_payload = quota_payload.with_ids(ids); } @@ -1192,18 +1538,18 @@ async fn collection_query( } server.quota_enforcer.enforce("a_payload).await?; tracing::info!( - "Querying records from collection [{collection_id}] in database [{database_name}] for tenant [{tenant_id}]", + "Querying records from collection [{collection_id}] in database [{database}] for tenant [{tenant}]", ); let _guard = server.scorecard_request(&[ "op:read", - format!("tenant:{}", tenant_id).as_str(), + format!("tenant:{}", tenant).as_str(), format!("collection:{}", collection_id).as_str(), ]); let request = QueryRequest::try_new( - tenant_id, - database_name, + tenant, + database, collection_id, payload.ids, parsed_where, @@ -1224,3 +1570,55 @@ async fn v1_deprecation_notice() -> Response { ); (StatusCode::GONE, Json(err_response)).into_response() } + +//////////////////////////////////////////////////////////// +/// OpenAPI +//////////////////////////////////////////////////////////// +struct ChromaTokenSecurityAddon; +impl Modify for ChromaTokenSecurityAddon { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + // NOTE(philipithomas) - This unwrap is usually safe, and will crash the service on initialization if it's not. + let components = openapi + .components + .as_mut() + .expect("It should be able to get components as mutable"); + components.add_security_scheme( + "x-chroma-token", + SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::new("x-chroma-token"))), + ); + } +} + +#[derive(OpenApi)] +#[openapi( + paths( + healthcheck, + heartbeat, + pre_flight_checks, + reset, + version, + get_user_identity, + create_tenant, + get_tenant, + list_databases, + create_database, + get_database, + delete_database, + create_collection, + list_collections, + count_collections, + get_collection, + update_collection, + delete_collection, + collection_add, + collection_update, + collection_upsert, + collection_delete, + collection_count, + collection_get, + collection_query + ), + // Apply our new security scheme here + modifiers(&ChromaTokenSecurityAddon) +)] +struct ApiDoc; diff --git a/rust/frontend/src/types/errors.rs b/rust/frontend/src/types/errors.rs index dba157c491b..3be52830e6e 100644 --- a/rust/frontend/src/types/errors.rs +++ b/rust/frontend/src/types/errors.rs @@ -6,7 +6,9 @@ use axum::{ use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{GetCollectionError, UpdateCollectionError}; use serde::Serialize; +use std::fmt; use thiserror::Error; +use utoipa::ToSchema; #[derive(Error, Debug)] pub enum ValidationError { @@ -18,7 +20,7 @@ pub enum ValidationError { DimensionMismatch(u32, u32), #[error("Error getting collection: {0}")] GetCollection(#[from] GetCollectionError), - #[error("Error updatding collection: {0}")] + #[error("Error updating collection: {0}")] UpdateCollection(#[from] UpdateCollectionError), } @@ -43,7 +45,13 @@ impl From for ServerError { } } -#[derive(Serialize)] +impl fmt::Display for ServerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Serialize, ToSchema)] pub struct ErrorResponse { error: String, message: String, diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index bec23b92af9..55845b8439e 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -131,7 +131,7 @@ impl Bindings { fn heartbeat(&self) -> ChromaPyResult { let duration_since_epoch = std::time::SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) - .map_err(HeartbeatError::CouldNotGetTime)?; + .map_err(HeartbeatError::from)?; Ok(duration_since_epoch.as_nanos()) } diff --git a/rust/types/Cargo.toml b/rust/types/Cargo.toml index 26f18dc9ee7..94147889295 100644 --- a/rust/types/Cargo.toml +++ b/rust/types/Cargo.toml @@ -19,7 +19,7 @@ tokio = { workspace = true } pyo3 = { workspace = true } validator = { workspace = true } regex = { workspace = true } - +utoipa = { workspace = true } # (Cross-crate testing dependencies) proptest = { workspace = true, optional = true } proptest-derive = { workspace = true, optional = true } diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 4e04cfc5430..32e9eb72300 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -26,6 +26,7 @@ use serde_json::Value; use std::time::SystemTimeError; use thiserror::Error; use tonic::Status; +use utoipa::ToSchema; use uuid::Uuid; use validator::Validate; use validator::ValidationError; @@ -113,21 +114,27 @@ impl ChromaError for ResetError { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct ChecklistResponse { pub max_batch_size: u32, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct HeartbeatResponse { #[serde(rename(serialize = "nanosecond heartbeat"))] pub nanosecond_heartbeat: u128, } -#[derive(Debug, Error)] +#[derive(Debug, Error, ToSchema)] pub enum HeartbeatError { - #[error(transparent)] - CouldNotGetTime(#[from] SystemTimeError), + #[error("system time error: {0}")] + CouldNotGetTime(String), +} + +impl From for HeartbeatError { + fn from(err: SystemTimeError) -> Self { + HeartbeatError::CouldNotGetTime(err.to_string()) + } } impl ChromaError for HeartbeatError { @@ -136,7 +143,7 @@ impl ChromaError for HeartbeatError { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct GetUserIdentityResponse { pub user_id: String, pub tenant: String, @@ -144,7 +151,7 @@ pub struct GetUserIdentityResponse { } #[non_exhaustive] -#[derive(Deserialize, Validate)] +#[derive(Serialize, Validate, Deserialize, ToSchema)] pub struct CreateTenantRequest { #[validate(length(min = 3))] pub name: String, @@ -158,7 +165,7 @@ impl CreateTenantRequest { } } -#[derive(Serialize)] +#[derive(Serialize, Deserialize, ToSchema)] pub struct CreateTenantResponse {} #[derive(Debug, Error)] @@ -192,7 +199,7 @@ impl GetTenantRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] #[pyclass] pub struct GetTenantResponse { pub name: String, @@ -240,7 +247,7 @@ impl CreateDatabaseRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct CreateDatabaseResponse {} #[derive(Error, Debug)] @@ -260,7 +267,7 @@ impl ChromaError for CreateDatabaseError { } } -#[derive(Serialize, Debug)] +#[derive(Serialize, Debug, ToSchema)] #[pyo3::pyclass] pub struct Database { pub id: Uuid, @@ -388,7 +395,7 @@ impl DeleteDatabaseRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct DeleteDatabaseResponse {} #[derive(Debug, Error)] @@ -638,7 +645,7 @@ impl UpdateCollectionRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct UpdateCollectionResponse {} #[derive(Error, Debug)] @@ -776,7 +783,7 @@ impl AddCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct AddCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -838,7 +845,7 @@ impl UpdateCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct UpdateCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -897,7 +904,7 @@ impl UpsertCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct UpsertCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -917,7 +924,7 @@ impl ChromaError for UpsertCollectionRecordsError { ////////////////////////// DeleteCollectionRecords ////////////////////////// #[non_exhaustive] -#[derive(Clone, Validate)] +#[derive(Clone, Validate, ToSchema)] pub struct DeleteCollectionRecordsRequest { pub tenant_id: String, pub database_name: String, @@ -954,7 +961,7 @@ impl DeleteCollectionRecordsRequest { } } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct DeleteCollectionRecordsResponse {} #[derive(Error, Debug)] @@ -986,7 +993,7 @@ impl ChromaError for IncludeParsingError { } } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] pub enum Include { #[serde(rename = "distances")] Distance, @@ -1015,7 +1022,7 @@ impl TryFrom<&str> for Include { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, PartialEq)] #[pyclass] pub struct IncludeList(pub Vec); @@ -1120,7 +1127,7 @@ impl GetRequest { } } -#[derive(Clone, Deserialize, Serialize, Debug)] +#[derive(Clone, Deserialize, Serialize, Debug, ToSchema)] #[pyclass] pub struct GetResponse { #[pyo3(get)] @@ -1235,7 +1242,7 @@ impl QueryRequest { } } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, ToSchema)] #[pyclass] pub struct QueryResponse { #[pyo3(get)] diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs index 1e91a821a67..1d0ec407b05 100644 --- a/rust/types/src/collection.rs +++ b/rust/types/src/collection.rs @@ -5,11 +5,23 @@ use pyo3::types::PyAnyMethods; use serde::{Deserialize, Serialize}; use serde_json::Value; use thiserror::Error; +use utoipa::ToSchema; use uuid::Uuid; /// CollectionUuid is a wrapper around Uuid to provide a type for the collection id. #[derive( - Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, + Copy, + Clone, + Debug, + Default, + Deserialize, + Eq, + PartialEq, + Ord, + PartialOrd, + Hash, + Serialize, + ToSchema, )] pub struct CollectionUuid(pub Uuid); @@ -36,7 +48,7 @@ impl std::fmt::Display for CollectionUuid { } } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[pyo3::pyclass] pub struct Collection { #[serde(rename(serialize = "id"))] diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index 7d98742304d..6aa7fba6892 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -7,10 +7,11 @@ use std::{ collections::{HashMap, HashSet}, }; use thiserror::Error; +use utoipa::ToSchema; use crate::chroma_proto; -#[derive(Clone, Debug, PartialEq, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, PartialEq, PartialOrd, Deserialize, Serialize, ToSchema)] #[serde(untagged)] pub enum UpdateMetadataValue { Bool(bool), @@ -118,7 +119,15 @@ MetadataValue */ #[derive( - Clone, Debug, Deserialize, PartialEq, PartialOrd, Serialize, FromPyObject, IntoPyObject, + Clone, + Debug, + Deserialize, + PartialEq, + PartialOrd, + Serialize, + FromPyObject, + IntoPyObject, + ToSchema, )] #[serde(untagged)] pub enum MetadataValue { @@ -438,7 +447,7 @@ impl WhereConversionError { /// present we simply create a conjunction of both clauses as the actual filter. This is consistent with /// the semantics we used to have when the `where` and `where_document` clauses are treated seperately. // TODO: Remove this note once the `where` clause and `where_document` clause is unified in the API level. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum Where { Composite(CompositeExpression), Document(DocumentExpression), @@ -512,7 +521,7 @@ impl TryFrom for chroma_proto::Where { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct CompositeExpression { pub operator: BooleanOperator, pub children: Vec, @@ -548,7 +557,7 @@ impl TryFrom for chroma_proto::WhereChildren { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum BooleanOperator { And, Or, @@ -572,7 +581,7 @@ impl From for chroma_proto::BooleanOperator { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct DocumentExpression { pub operator: DocumentOperator, pub text: String, @@ -596,7 +605,7 @@ impl From for chroma_proto::DirectWhereDocument { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum DocumentOperator { Contains, NotContains, @@ -619,7 +628,7 @@ impl From for chroma_proto::WhereDocumentOperator { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub struct MetadataExpression { pub key: String, pub comparison: MetadataComparison, @@ -749,7 +758,7 @@ impl TryFrom for chroma_proto::DirectComparison { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, ToSchema)] pub enum MetadataComparison { Primitive(PrimitiveOperator, MetadataValue), Set(SetOperator, MetadataSetValue), diff --git a/rust/types/src/where_parsing.rs b/rust/types/src/where_parsing.rs index 33cedc47477..757ec43cb6b 100644 --- a/rust/types/src/where_parsing.rs +++ b/rust/types/src/where_parsing.rs @@ -1,10 +1,12 @@ use crate::{CompositeExpression, DocumentOperator, MetadataExpression, PrimitiveOperator, Where}; use chroma_error::ChromaError; use serde::Deserialize; +use serde::Serialize; use serde_json::Value; use thiserror::Error; +use utoipa::ToSchema; -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Serialize, ToSchema)] pub struct RawWhereFields { #[serde(default)] r#where: Value,