diff --git a/chromadb/api/async_fastapi.py b/chromadb/api/async_fastapi.py
index 99f09c3029e..b0e8de18cb1 100644
--- a/chromadb/api/async_fastapi.py
+++ b/chromadb/api/async_fastapi.py
@@ -6,6 +6,14 @@
 import logging
 import httpx
 from overrides import override
+from tenacity import (
+    AsyncRetrying,
+    before_sleep_log,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)
 from chromadb import __version__
 from chromadb.auth import UserIdentity
 from chromadb.api.async_api import AsyncServerAPI
@@ -16,6 +24,7 @@
     create_collection_configuration_to_json,
     update_collection_configuration_to_json,
 )
+from chromadb.api.fastapi import is_retryable_exception
 from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings
 from chromadb.telemetry.opentelemetry import (
     OpenTelemetryClient,
@@ -140,20 +149,71 @@ def _get_client(self) -> httpx.AsyncClient:
     async def _make_request(
         self, method: str, path: str, **kwargs: Dict[str, Any]
     ) -> Any:
-        # If the request has json in kwargs, use orjson to serialize it,
-        # remove it from kwargs, and add it to the content parameter
-        # This is because httpx uses a slower json serializer
-        if "json" in kwargs:
-            data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
-            kwargs["content"] = data
-
-        # Unlike requests, httpx does not automatically escape the path
-        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
-        url = self._api_url + escaped_path
-
-        response = await self._get_client().request(method, url, **cast(Any, kwargs))
-        BaseHTTPClient._raise_chroma_error(response)
-        return orjson.loads(response.text)
+        """Send a request to the Chroma server and return the decoded JSON body.
+
+        When ``settings.retry_config`` is set, failures accepted by
+        ``is_retryable_exception`` are retried with exponential backoff.
+        ``retry_config=None`` sends the request exactly once.
+        """
+
+        async def _send_request() -> Any:
+            # If the request has json in kwargs, use orjson to serialize it,
+            # remove it from kwargs, and add it to the content parameter
+            # This is because httpx uses a slower json serializer.
+            # On a retry "json" has already been popped, so the "content"
+            # serialized by the first attempt is reused.
+            if "json" in kwargs:
+                data = orjson.dumps(
+                    kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY
+                )
+                kwargs["content"] = data
+
+            # Unlike requests, httpx does not automatically escape the path
+            escaped_path = urllib.parse.quote(
+                path, safe="/", encoding=None, errors=None
+            )
+            url = self._api_url + escaped_path
+
+            response = await self._get_client().request(
+                method, url, **cast(Any, kwargs)
+            )
+            BaseHTTPClient._raise_chroma_error(response)
+            return orjson.loads(response.text)
+
+        retry_config = self._settings.retry_config
+
+        if retry_config is None:
+            return await _send_request()
+
+        # Clamp the configured values so tenacity always gets a valid schedule.
+        min_delay = max(float(retry_config.min_delay), 0.0)
+        max_delay = max(float(retry_config.max_delay), min_delay)
+        multiplier = max(min_delay, 1e-3)
+        exp_base = retry_config.factor if retry_config.factor > 0 else 2.0
+
+        wait_args = {
+            "multiplier": multiplier,
+            "min": min_delay,
+            "max": max_delay,
+            "exp_base": exp_base,
+        }
+
+        wait_strategy = (
+            wait_random_exponential(**wait_args)
+            if retry_config.jitter
+            else wait_exponential(**wait_args)
+        )
+
+        retrying = AsyncRetrying(
+            stop=stop_after_attempt(retry_config.max_attempts),
+            wait=wait_strategy,
+            retry=retry_if_exception(is_retryable_exception),
+            before_sleep=before_sleep_log(logger, logging.INFO),
+            # Re-raise the last underlying error instead of tenacity.RetryError.
+            reraise=True,
+        )
+
+        return await retrying(_send_request)
 
     @trace_method("AsyncFastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
     @override
diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py
index 24e9cbfca5a..f751189b5df 100644
--- a/chromadb/api/fastapi.py
+++ b/chromadb/api/fastapi.py
@@ -6,6 +6,14 @@
 import httpx
 import urllib.parse
 from overrides import override
+from tenacity import (
+    Retrying,
+    before_sleep_log,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)
 
 from chromadb.api.collection_configuration import (
     CreateCollectionConfiguration,
@@ -57,6 +65,35 @@
 logger = logging.getLogger(__name__)
 
 
+def is_retryable_exception(exception: BaseException) -> bool:
+    """Return True when ``exception`` is a transient failure worth retrying.
+
+    Covers httpx connection, timeout and protocol errors, plus
+    ``httpx.HTTPStatusError`` responses with status 502, 503 or 504.
+    """
+    if isinstance(
+        exception,
+        (
+            httpx.ConnectError,
+            httpx.ConnectTimeout,
+            httpx.ReadTimeout,
+            httpx.WriteTimeout,
+            httpx.PoolTimeout,
+            httpx.NetworkError,
+            httpx.RemoteProtocolError,
+        ),
+    ):
+        return True
+
+    if isinstance(exception, httpx.HTTPStatusError):
+        # Retry on server errors that might be temporary.
+        # NOTE(review): only reachable if _raise_chroma_error lets
+        # HTTPStatusError propagate for these statuses - confirm.
+        return exception.response.status_code in (502, 503, 504)
+
+    return False
+
+
 class FastAPI(BaseHTTPClient, ServerAPI):
     def __init__(self, system: System):
         super().__init__(system)
@@ -97,20 +134,69 @@ def __init__(self, system: System):
                 self._session.headers[header] = value.get_secret_value()
 
     def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
-        # If the request has json in kwargs, use orjson to serialize it,
-        # remove it from kwargs, and add it to the content parameter
-        # This is because httpx uses a slower json serializer
-        if "json" in kwargs:
-            data = orjson.dumps(kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY)
-            kwargs["content"] = data
-
-        # Unlike requests, httpx does not automatically escape the path
-        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
-        url = self._api_url + escaped_path
-
-        response = self._session.request(method, url, **cast(Any, kwargs))
-        BaseHTTPClient._raise_chroma_error(response)
-        return orjson.loads(response.text)
+        """Send a request to the Chroma server and return the decoded JSON body.
+
+        When ``settings.retry_config`` is set, failures accepted by
+        ``is_retryable_exception`` are retried with exponential backoff.
+        ``retry_config=None`` sends the request exactly once.
+        """
+
+        def _send_request() -> Any:
+            # If the request has json in kwargs, use orjson to serialize it,
+            # remove it from kwargs, and add it to the content parameter
+            # This is because httpx uses a slower json serializer.
+            # On a retry "json" has already been popped, so the "content"
+            # serialized by the first attempt is reused.
+            if "json" in kwargs:
+                data = orjson.dumps(
+                    kwargs.pop("json"), option=orjson.OPT_SERIALIZE_NUMPY
+                )
+                kwargs["content"] = data
+
+            # Unlike requests, httpx does not automatically escape the path
+            escaped_path = urllib.parse.quote(
+                path, safe="/", encoding=None, errors=None
+            )
+            url = self._api_url + escaped_path
+
+            response = self._session.request(method, url, **cast(Any, kwargs))
+            BaseHTTPClient._raise_chroma_error(response)
+            return orjson.loads(response.text)
+
+        retry_config = self._settings.retry_config
+
+        if retry_config is None:
+            return _send_request()
+
+        # Clamp the configured values so tenacity always gets a valid schedule.
+        min_delay = max(float(retry_config.min_delay), 0.0)
+        max_delay = max(float(retry_config.max_delay), min_delay)
+        multiplier = max(min_delay, 1e-3)
+        exp_base = retry_config.factor if retry_config.factor > 0 else 2.0
+
+        wait_args = {
+            "multiplier": multiplier,
+            "min": min_delay,
+            "max": max_delay,
+            "exp_base": exp_base,
+        }
+
+        wait_strategy = (
+            wait_random_exponential(**wait_args)
+            if retry_config.jitter
+            else wait_exponential(**wait_args)
+        )
+
+        retrying = Retrying(
+            stop=stop_after_attempt(retry_config.max_attempts),
+            wait=wait_strategy,
+            retry=retry_if_exception(is_retryable_exception),
+            before_sleep=before_sleep_log(logger, logging.INFO),
+            # Re-raise the last underlying error instead of tenacity.RetryError.
+            reraise=True,
+        )
+
+        return retrying(_send_request)
 
     @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
     @override
diff --git a/chromadb/config.py b/chromadb/config.py
index d2a4a87b725..37172ebbfa4 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -11,6 +11,7 @@
 from overrides import override
 from typing_extensions import Literal
 import platform
+from pydantic import BaseModel
 
 in_pydantic_v2 = False
 try:
@@ -97,6 +98,14 @@ class APIVersion(str, Enum):
     V2 = "/api/v2"
 
 
+class RetryConfig(BaseModel):  # backoff policy for HTTP client retries
+    factor: float = 2.0  # exponential base (tenacity ``exp_base``)
+    min_delay: float = 1.0  # seconds; passed as the wait strategy's ``min``
+    max_delay: float = 5.0  # seconds; passed as the wait strategy's ``max``
+    max_attempts: int = 5  # passed to ``stop_after_attempt``
+    jitter: bool = True  # True selects ``wait_random_exponential``
+
+
 # NOTE(hammadb) 1/13/2024 - This has to be in config.py instead of being localized to the module
 # that uses it because of a circular import issue. This is a temporary solution until we can
 # refactor the code to remove the circular import.
@@ -133,6 +142,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
             return None
         return v
 
+    retry_config: Optional[RetryConfig] = RetryConfig()
+
     chroma_server_nofile: Optional[int] = None
     # the number of maximum threads to handle synchronous tasks in the FastAPI server
     chroma_server_thread_pool_size: int = 40
diff --git a/clients/new-js/packages/chromadb/src/admin-client.ts b/clients/new-js/packages/chromadb/src/admin-client.ts
index 4fa0dca49f9..fa986c5de2a 100644
--- a/clients/new-js/packages/chromadb/src/admin-client.ts
+++ b/clients/new-js/packages/chromadb/src/admin-client.ts
@@ -1,7 +1,8 @@
 import { defaultAdminClientArgs, HttpMethod, normalizeMethod } from "./utils";
 import { createClient, createConfig } from "@hey-api/client-fetch";
 import { Database, DefaultService as Api } from "./api";
-import { chromaFetch } from "./chroma-fetch";
+import { createChromaFetch } from "./chroma-fetch";
+import type { RetryConfig } from "./retry";
 
 /**
  * Configuration options for the AdminClient.
@@ -17,6 +18,8 @@ export interface AdminClientArgs {
   headers?: Record<string, string>;
   /** Additional fetch options for HTTP requests */
   fetchOptions?: RequestInit;
+  /** Retry configuration for HTTP requests. Set to null to disable retries */
+  retryConfig?: RetryConfig | null;
 }
 
 /**
@@ -43,8 +46,17 @@ export class AdminClient {
    * @param args - Optional configuration for the admin client
    */
   constructor(args?: AdminClientArgs) {
-    const { host, port, ssl, headers, fetchOptions } =
-      args || defaultAdminClientArgs;
+    const {
+      host,
+      port,
+      ssl,
+      headers,
+      fetchOptions,
+      retryConfig,
+    } = {
+      ...defaultAdminClientArgs,
+      ...(args ?? {}),
+    };
 
     const baseUrl = `${ssl ? "https" : "http"}://${host}:${port}`;
 
@@ -56,7 +68,7 @@ export class AdminClient {
     };
 
     this.apiClient = createClient(createConfig(configOptions));
-    this.apiClient.setConfig({ fetch: chromaFetch });
+    this.apiClient.setConfig({ fetch: createChromaFetch({ retryConfig }) });
   }
 
   /**
diff --git a/clients/new-js/packages/chromadb/src/chroma-client.ts b/clients/new-js/packages/chromadb/src/chroma-client.ts
index f285efe1176..216692ca7b9 100644
--- a/clients/new-js/packages/chromadb/src/chroma-client.ts
+++ b/clients/new-js/packages/chromadb/src/chroma-client.ts
@@ -9,7 +9,8 @@ import { DefaultService as Api, ChecklistResponse } from "./api";
 import { CollectionMetadata, UserIdentity } from "./types";
 import { Collection, CollectionImpl } from "./collection";
 import { EmbeddingFunction, getEmbeddingFunction } from "./embedding-function";
-import { chromaFetch } from "./chroma-fetch";
+import { createChromaFetch } from "./chroma-fetch";
+import type { RetryConfig } from "./retry";
 import * as process from "node:process";
 import {
   ChromaConnectionError,
@@ -39,6 +40,8 @@ export interface ChromaClientArgs {
   headers?: Record<string, string>;
   /** Additional fetch options for HTTP requests */
   fetchOptions?: RequestInit;
+  /** Retry configuration for HTTP requests. Set to null to disable retries */
+  retryConfig?: RetryConfig | null;
   /** @deprecated Use host, port, and ssl instead */
   path?: string;
   /** @deprecated */
@@ -68,6 +71,7 @@ export class ChromaClient {
       database = defaultArgs.database,
       headers = defaultArgs.headers,
       fetchOptions = defaultArgs.fetchOptions,
+      retryConfig = defaultArgs.retryConfig,
     } = args;
 
     if (args.path) {
@@ -109,7 +113,7 @@ export class ChromaClient {
     };
 
     this.apiClient = createClient(createConfig(configOptions));
-    this.apiClient.setConfig({ fetch: chromaFetch });
+    this.apiClient.setConfig({ fetch: createChromaFetch({ retryConfig }) });
   }
 
   /**
diff --git a/clients/new-js/packages/chromadb/src/chroma-fetch.ts b/clients/new-js/packages/chromadb/src/chroma-fetch.ts
index cd745df432f..baded205885 100644
--- a/clients/new-js/packages/chromadb/src/chroma-fetch.ts
+++ b/clients/new-js/packages/chromadb/src/chroma-fetch.ts
@@ -8,45 +8,75 @@ import {
   ChromaUnauthorizedError,
   ChromaUniqueError,
 } from "./errors";
+import { defaultRetryConfig, RetryConfig } from "./retry";
 
-const offlineError = (error: any): boolean => {
-  return Boolean(
-    (error?.name === "TypeError" || error?.name === "FetchError") &&
-      (error.message?.includes("fetch failed") ||
-        error.message?.includes("Failed to fetch") ||
-        error.message?.includes("ENOTFOUND")),
-  );
+/** HTTP statuses treated as transient and worth retrying. */
+const RETRYABLE_STATUS = new Set([502, 503, 504]);
+
+/** Whether a thrown value is named like a network-level fetch failure. */
+const isRetryableError = (error: unknown): boolean => {
+  if (!error || typeof error !== "object") {
+    return false;
+  }
+
+  const name = (error as { name?: unknown }).name;
+  return name === "TypeError" || name === "FetchError";
 };
 
-export const chromaFetch: typeof fetch = async (input, init) => {
-  let response: Response;
-  try {
-    response = await fetch(input, init);
-  } catch (err) {
-    if (offlineError(err)) {
-      throw new ChromaConnectionError(
-        "Failed to connect to chromadb. Make sure your server is running and try again. If you are running from a browser, make sure that your chromadb instance is configured to allow requests from the current origin using the CHROMA_SERVER_CORS_ALLOW_ORIGINS environment variable.",
-      );
-    }
-    throw new ChromaConnectionError("Failed to connect to Chroma");
+const CONNECTION_ERROR_MESSAGE =
+  "Failed to connect to chromadb. Make sure your server is running and try again. If you are running from a browser, make sure that your chromadb instance is configured to allow requests from the current origin using the CHROMA_SERVER_CORS_ALLOW_ORIGINS environment variable.";
+
+const shouldRetryResponse = (status: number): boolean =>
+  RETRYABLE_STATUS.has(status);
+
+const shouldRetryError = (error: unknown): boolean => isRetryableError(error);
+
+/**
+ * Backoff before the next try, in seconds: `minDelay * factor^(attempt - 1)`
+ * clamped to `[minDelay, maxDelay]`, then scaled by a random factor in
+ * `[0, 1)` when `jitter` is set.
+ */
+const computeDelaySeconds = (config: RetryConfig, attempt: number): number => {
+  const exponent = Math.max(attempt - 1, 0);
+  const exponentialDelay = config.minDelay * Math.pow(config.factor, exponent);
+  const capped = Math.min(
+    config.maxDelay,
+    Math.max(config.minDelay, exponentialDelay),
+  );
+  if (!config.jitter) {
+    return capped;
   }
+  return Math.random() * capped;
+};
 
-  if (response.ok) {
-    return response;
+/** Resolve after `seconds`; non-positive values resolve immediately. */
+const sleep = async (seconds: number): Promise<void> => {
+  if (seconds <= 0) {
+    return;
   }
+  await new Promise<void>((resolve) => setTimeout(resolve, seconds * 1000));
+};
+
+const buildConnectionError = (error: unknown): ChromaConnectionError =>
+  new ChromaConnectionError(CONNECTION_ERROR_MESSAGE, error);
 
+/** Throw the Chroma error matching a non-ok `response`; never returns. */
+const throwForResponse = async (
+  response: Response,
+  input: RequestInfo | URL,
+): Promise<never> => {
   switch (response.status) {
-    case 400:
+    case 400: {
       let status = "Bad Request";
       try {
         const responseBody = await response.json();
         status = responseBody.message || status;
       } catch {}
       throw new ChromaClientError(
-        `Bad request to ${
-          (input as Request).url || "Chroma"
+        `Bad request to ${(input as Request).url || "Chroma"
         } with status: ${status}`,
       );
+    }
     case 401:
       throw new ChromaUnauthorizedError(`Unauthorized`);
     case 403:
@@ -59,17 +89,23 @@ export const chromaFetch: typeof fetch = async (input, init) => {
       );
     case 409:
       throw new ChromaUniqueError("The resource already exists");
-    case 422:
-      const body = await response.json();
-      if (
-        body &&
-        body.message &&
-        (body.message.startsWith("Quota exceeded") ||
-          body.message.startsWith("Billing limit exceeded"))
-      ) {
-        throw new ChromaQuotaExceededError(body?.message);
-      }
+    case 422: {
+      // Parse defensively: a non-JSON 422 body falls through to the generic
+      // error below. The quota check stays outside the try so the typed
+      // error is not swallowed by the parse guard.
+      let message: unknown;
+      try {
+        message = (await response.json())?.message;
+      } catch {}
+      if (
+        typeof message === "string" &&
+        (message.startsWith("Quota exceeded") ||
+          message.startsWith("Billing limit exceeded"))
+      ) {
+        throw new ChromaQuotaExceededError(message);
+      }
       break;
+    }
     case 429:
       throw new ChromaRateLimitError("Rate limit exceeded");
   }
@@ -78,3 +114,64 @@ export const chromaFetch: typeof fetch = async (input, init) => {
     `Unable to connect to the chromadb server (status: ${response.status}). Please try again later.`,
   );
 };
+
+/**
+ * Build a `fetch` wrapper that maps failed HTTP responses to Chroma errors
+ * and retries transient failures (network errors and 502/503/504 responses).
+ *
+ * @param options - `retryConfig` controls the backoff: `undefined` uses
+ *   `defaultRetryConfig`, `null` disables retries.
+ * @returns A `fetch`-compatible function that resolves with `ok` responses,
+ *   throws the matching Chroma error for other statuses, and throws
+ *   `ChromaConnectionError` when the request itself fails.
+ */
+export const createChromaFetch = (options?: {
+  retryConfig?: RetryConfig | null;
+}): typeof fetch => {
+  const userConfig = options?.retryConfig;
+  const retriesEnabled = userConfig !== null;
+  const config = userConfig ?? defaultRetryConfig;
+  const maxAttempts = retriesEnabled ? Math.max(config.maxAttempts, 1) : 1;
+
+  return async (input, init) => {
+    for (let attempt = 1; ; attempt += 1) {
+      const canRetry = attempt < maxAttempts;
+
+      let response: Response;
+      try {
+        // A Request body can be consumed only once, so send a clone while
+        // another attempt may still follow.
+        const request =
+          canRetry && typeof Request !== "undefined" && input instanceof Request
+            ? input.clone()
+            : input;
+        response = await fetch(request, init);
+      } catch (err) {
+        if (canRetry && shouldRetryError(err)) {
+          await sleep(computeDelaySeconds(config, attempt));
+          continue;
+        }
+        throw buildConnectionError(err);
+      }
+
+      if (response.ok) {
+        return response;
+      }
+
+      if (canRetry && shouldRetryResponse(response.status)) {
+        // Discard the failed body before retrying; cancel errors are ignored.
+        try {
+          await response.body?.cancel();
+        } catch {}
+        await sleep(computeDelaySeconds(config, attempt));
+        continue;
+      }
+
+      // Kept outside the try above so typed HTTP errors (401, 409, 429, ...)
+      // are not re-wrapped as ChromaConnectionError.
+      return throwForResponse(response, input);
+    }
+  };
+};
+
+export const chromaFetch = createChromaFetch();
diff --git a/clients/new-js/packages/chromadb/src/retry.ts b/clients/new-js/packages/chromadb/src/retry.ts
new file mode 100644
index 00000000000..2f4db55e1d2
--- /dev/null
+++ b/clients/new-js/packages/chromadb/src/retry.ts
@@ -0,0 +1,22 @@
+/** Exponential-backoff settings for retried HTTP requests. */
+export interface RetryConfig {
+  /** Base of the exponential growth between attempts. */
+  factor: number;
+  /** Lower clamp of the un-jittered backoff, in seconds. */
+  minDelay: number;
+  /** Upper clamp of the backoff, in seconds. */
+  maxDelay: number;
+  /** Total number of tries, including the first request. */
+  maxAttempts: number;
+  /** Whether to randomize each backoff. */
+  jitter: boolean;
+}
+
+/** Defaults used when no `retryConfig` is supplied. */
+export const defaultRetryConfig: RetryConfig = {
+  factor: 2.0,
+  minDelay: 0.1,
+  maxDelay: 5.0,
+  maxAttempts: 5,
+  jitter: true,
+};
diff --git a/clients/new-js/packages/chromadb/src/utils.ts b/clients/new-js/packages/chromadb/src/utils.ts
index e44fd908225..5ea141f8d3b 100644
--- a/clients/new-js/packages/chromadb/src/utils.ts
+++ b/clients/new-js/packages/chromadb/src/utils.ts
@@ -1,5 +1,6 @@
-import { AdminClientArgs } from "./admin-client";
-import { ChromaClientArgs } from "./chroma-client";
+import type { AdminClientArgs } from "./admin-client";
+import type { ChromaClientArgs } from "./chroma-client";
+import { defaultRetryConfig } from "./retry";
 import {
   BaseRecordSet,
   IncludeEnum,
@@ -22,6 +23,7 @@ export const defaultAdminClientArgs: AdminClientArgs = {
   host: "localhost",
   port: 8000,
   ssl: false,
+  retryConfig: defaultRetryConfig,
 };
 
 /** Default configuration for ChromaClient connections */