diff --git a/.github/workflows/_python-tests.yml b/.github/workflows/_python-tests.yml index 5b786e7acc4..6bd226573ac 100644 --- a/.github/workflows/_python-tests.yml +++ b/.github/workflows/_python-tests.yml @@ -141,6 +141,7 @@ jobs: - "chromadb/test/distributed/test_sanity.py" - "chromadb/test/distributed/test_log_backpressure.py" - "chromadb/test/distributed/test_repair_collection_log_offset.py" + - "chromadb/test/distributed/test_task_api.py" include: - test-glob: "chromadb/test/property/test_add.py" parallelized: false diff --git a/Cargo.lock b/Cargo.lock index 6825079b7e4..f7802002a29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1928,6 +1928,7 @@ dependencies = [ "futures", "parking_lot", "prost 0.13.5", + "prost-types 0.13.5", "sea-query", "sea-query-binder", "serde", diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 2234459e88d..91f597fbc06 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Sequence, Optional, List +from typing import Sequence, Optional, List, Dict, Any from uuid import UUID from overrides import override @@ -775,3 +775,53 @@ def _delete( database: str = DEFAULT_DATABASE, ) -> None: pass + + @abstractmethod + def create_task( + self, + task_name: str, + operator_name: str, + input_collection_id: UUID, + output_collection_name: str, + params: Optional[Dict[str, Any]] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> tuple[bool, str]: + """Create a recurring task on a collection. + + Args: + task_name: Unique name for this task instance + operator_name: Built-in operator name (e.g., 'record_counter') + input_collection_id: Source collection that triggers the task + output_collection_name: Target collection where task output is stored + params: Optional dictionary with operator-specific parameters + tenant: The tenant name + database: The database name + + Returns: + tuple: (success: bool, task_id: str) + """ + pass + + @abstractmethod + def remove_task( + self, + task_name: str, + input_collection_id: UUID, + delete_output: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + """Delete a task and prevent any further runs. + + Args: + task_name: Name of the task to remove + input_collection_id: Id of the input collection the task is registered on + delete_output: Whether to also delete the output collection + tenant: The tenant name + database: The database name + + Returns: + bool: True if successful + """ + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index f12c3f5cf45..504f076fb35 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -702,3 +702,49 @@ def get_max_batch_size(self) -> int: pre_flight_checks = self.get_pre_flight_checks() max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1)) return max_batch_size + + @trace_method("FastAPI.create_task", OpenTelemetryGranularity.ALL) + @override + def create_task( + self, + task_name: str, + operator_name: str, + input_collection_id: UUID, + output_collection_name: str, + params: Optional[Dict[str, Any]] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> tuple[bool, str]: + """Register a recurring task on a collection.""" + resp_json = self._make_request( + "post", + f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/create", + json={ + "task_name": task_name, + "operator_name": operator_name, + "output_collection_name": output_collection_name, + "params": params, + }, + ) + return cast(bool, resp_json["success"]), cast(str, resp_json["task_id"]) + + @trace_method("FastAPI.remove_task", OpenTelemetryGranularity.ALL) + @override + def remove_task( + self, + task_name: str, + input_collection_id: UUID, + delete_output: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + """Delete a task and prevent any further runs.""" + resp_json = self._make_request( + "post", + f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/delete", + json={ + "task_name": task_name, + "delete_output": delete_output, + }, + ) + return cast(bool, resp_json["success"]) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 13bc47e0874..6ba6a693a61 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union, List, cast +from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any from chromadb.api.models.CollectionCommon import CollectionCommon from chromadb.api.types import ( @@ -327,29 +327,29 @@ def search( from chromadb.execution.expression import ( Search, Key, K, Knn, Val ) - + # Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT search = (Search() .where((K("category") == "science") & (K("score") > 0.5)) .rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2) .limit(10, offset=0) .select(K.DOCUMENT, K.SCORE, "title")) - + # Direct construction from chromadb.execution.expression import ( Search, Eq, And, Gt, Knn, Limit, Select, Key ) - + search = Search( where=And([Eq("category", "science"), Gt("score", 0.5)]), rank=Knn(query=[0.1, 0.2, 0.3]), limit=Limit(offset=0, limit=10), select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"}) ) - + # Single search result = collection.search(search) - + # Multiple searches at once searches = [ Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])), @@ -490,3 +490,64 @@ def delete( tenant=self.tenant, database=self.database, ) + + def create_task( + self, + task_name: str, + operator_name: str, + output_collection_name: str, + params: Optional[Dict[str, Any]] = None, + ) -> tuple[bool, str]: + """Create a recurring task that processes this collection. + + Args: + task_name: Unique name for this task instance + operator_name: Built-in operator name (e.g., "record_counter") + output_collection_name: Name of the collection where task output will be stored + params: Optional dictionary with operator-specific parameters + + Returns: + tuple: (success: bool, task_id: str) + + Example: + >>> success, task_id = collection.create_task( + ... task_name="count_docs", + ... operator_name="record_counter", + ... output_collection_name="doc_counts", + ... params={"threshold": 100} + ... ) + """ + return self._client.create_task( + task_name=task_name, + operator_name=operator_name, + input_collection_id=self.id, + output_collection_name=output_collection_name, + params=params, + tenant=self.tenant, + database=self.database, + ) + + def remove_task( + self, + task_name: str, + delete_output: bool = False, + ) -> bool: + """Delete a task and prevent any further runs. + + Args: + task_name: Name of the task to remove + delete_output: Whether to also delete the output collection. Defaults to False. + + Returns: + bool: True if successful + + Example: + >>> success = collection.remove_task("count_docs", delete_output=True) + """ + return self._client.remove_task( + task_name=task_name, + input_collection_id=self.id, + delete_output=delete_output, + tenant=self.tenant, + database=self.database, + ) diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 9f7e34cea1a..3aae75d030e 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -44,7 +44,7 @@ import chromadb_rust_bindings -from typing import Optional, Sequence, List +from typing import Optional, Sequence, List, Dict, Any from overrides import override from uuid import UUID import json @@ -587,6 +587,38 @@ def get_settings(self) -> Settings: def get_max_batch_size(self) -> int: return self.bindings.get_max_batch_size() + @override + def create_task( + self, + task_name: str, + operator_name: str, + input_collection_id: UUID, + output_collection_name: str, + params: Optional[Dict[str, Any]] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> tuple[bool, str]: + """Tasks are not supported in the Rust bindings (local embedded mode).""" + raise NotImplementedError( + "Tasks are only supported when connecting to a Chroma server via HttpClient. " + "The Rust bindings (embedded mode) do not support task operations." + ) + + @override + def remove_task( + self, + task_name: str, + input_collection_id: UUID, + delete_output: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + """Tasks are not supported in the Rust bindings (local embedded mode).""" + raise NotImplementedError( + "Tasks are only supported when connecting to a Chroma server via HttpClient. " + "The Rust bindings (embedded mode) do not support task operations." + ) + # TODO: Remove this if it's not planned to be used @override def get_user_identity(self) -> UserIdentity: diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index bf0f494d52b..5961528a0d2 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -67,6 +67,7 @@ Generator, List, Any, + Dict, Callable, TypeVar, ) @@ -906,6 +907,38 @@ def get_settings(self) -> Settings: def get_max_batch_size(self) -> int: return self._producer.max_batch_size + @override + def create_task( + self, + task_name: str, + operator_name: str, + input_collection_id: UUID, + output_collection_name: str, + params: Optional[Dict[str, Any]] = None, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> tuple[bool, str]: + """Tasks are not supported in the Segment API (local embedded mode).""" + raise NotImplementedError( + "Tasks are only supported when connecting to a Chroma server via HttpClient. " + "The Segment API (embedded mode) does not support task operations." + ) + + @override + def remove_task( + self, + task_name: str, + input_collection_id: UUID, + delete_output: bool = False, + tenant: str = DEFAULT_TENANT, + database: str = DEFAULT_DATABASE, + ) -> bool: + """Tasks are not supported in the Segment API (local embedded mode).""" + raise NotImplementedError( + "Tasks are only supported when connecting to a Chroma server via HttpClient. " + "The Segment API (embedded mode) does not support task operations." + ) + # TODO: This could potentially cause race conditions in a distributed version of the # system, since the cache is only local. # TODO: promote collection -> topic to a base class method so that it can be diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py new file mode 100644 index 00000000000..32d9a961e70 --- /dev/null +++ b/chromadb/test/distributed/test_task_api.py @@ -0,0 +1,208 @@ +""" +Integration test for Chroma's Task API + +Tests the task creation, execution, and removal functionality +for automatically processing collections. +""" + +import pytest +from chromadb.api.client import Client as ClientCreator +from chromadb.config import System +from chromadb.errors import ChromaError, NotFoundError + + +def test_task_create_and_remove(basic_http_client: System) -> None: + """Test creating and removing a task with the record_counter operator""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + # Create a collection + collection = client.get_or_create_collection( + name="my_document", + metadata={"description": "Sample documents for task processing"}, + ) + + # Add initial documents + collection.add( + ids=["doc1", "doc2", "doc3"], + documents=[ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "Python is a popular programming language", + ], + metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}], + ) + + # Verify collection has documents + assert collection.count() == 3 + + # Create a task that counts records in the collection + success, task_id = collection.create_task( + task_name="count_my_docs", + operator_name="record_counter", # Built-in operator that counts records + output_collection_name="my_documents_counts", + params=None, + ) + + # Verify task creation succeeded + assert success is True + assert task_id is not None + assert len(task_id) > 0 + + # Add more documents + collection.add( + ids=["doc4", "doc5"], + documents=[ + "Chroma is a vector database", + "Tasks automate data processing", + ], + ) + + # Verify documents were added + assert collection.count() == 5 + + # Remove the task + success = collection.remove_task( + task_name="count_my_docs", + delete_output=True, + ) + + # Verify task removal succeeded + assert success is True + + +def test_task_with_invalid_operator(basic_http_client: System) -> None: + """Test that creating a task with an invalid operator raises an error""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + collection = client.get_or_create_collection(name="test_invalid_operator") + collection.add(ids=["id1"], documents=["test document"]) + + # Attempt to create task with non-existent operator should raise ChromaError + with pytest.raises(ChromaError, match="operator not found"): + collection.create_task( + task_name="invalid_task", + operator_name="nonexistent_operator", + output_collection_name="output_collection", + params=None, + ) + + +def test_task_multiple_collections(basic_http_client: System) -> None: + """Test creating tasks on multiple collections""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + # Create first collection and task + collection1 = client.create_collection(name="collection_1") + collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"]) + + success1, task_id1 = collection1.create_task( + task_name="task_1", + operator_name="record_counter", + output_collection_name="output_1", + params=None, + ) + + assert success1 is True + assert task_id1 is not None + + # Create second collection and task + collection2 = client.create_collection(name="collection_2") + collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"]) + + success2, task_id2 = collection2.create_task( + task_name="task_2", + operator_name="record_counter", + output_collection_name="output_2", + params=None, + ) + + assert success2 is True + assert task_id2 is not None + + # Task IDs should be different + assert task_id1 != task_id2 + + # Clean up + assert collection1.remove_task(task_name="task_1", delete_output=True) is True + assert collection2.remove_task(task_name="task_2", delete_output=True) is True + + +def test_task_multiple_tasks(basic_http_client: System) -> None: + """Test creating multiple tasks on the same collection""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + # Create a single collection + collection = client.create_collection(name="multi_task_collection") + collection.add(ids=["id1", "id2", "id3"], documents=["doc1", "doc2", "doc3"]) + + # Create first task on the collection + success1, task_id1 = collection.create_task( + task_name="task_1", + operator_name="record_counter", + output_collection_name="output_1", + params=None, + ) + + assert success1 is True + assert task_id1 is not None + + # Create second task on the SAME collection with a different name + success2, task_id2 = collection.create_task( + task_name="task_2", + operator_name="record_counter", + output_collection_name="output_2", + params=None, + ) + + assert success2 is True + assert task_id2 is not None + + # Task IDs should be different even though they're on the same collection + assert task_id1 != task_id2 + + # Create third task on the same collection + success3, task_id3 = collection.create_task( + task_name="task_3", + operator_name="record_counter", + output_collection_name="output_3", + params=None, + ) + + assert success3 is True + assert task_id3 is not None + assert task_id3 != task_id1 + assert task_id3 != task_id2 + + # Attempt to create a task with duplicate name on same collection should fail + with pytest.raises(ChromaError, match="already exists"): + collection.create_task( + task_name="task_1", # Duplicate name + operator_name="record_counter", + output_collection_name="output_duplicate", + params=None, + ) + + # Clean up - remove each task individually + assert collection.remove_task(task_name="task_1", delete_output=True) is True + assert collection.remove_task(task_name="task_2", delete_output=True) is True + assert collection.remove_task(task_name="task_3", delete_output=True) is True + + +def test_task_remove_nonexistent(basic_http_client: System) -> None: + """Test removing a task that doesn't exist raises NotFoundError""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + collection = client.create_collection(name="test_collection") + collection.add(ids=["id1"], documents=["test"]) + + # Try to remove a task that was never created should raise NotFoundError + with pytest.raises(NotFoundError, match="does not exist"): + collection.remove_task( + task_name="nonexistent_task", + delete_output=False, + ) diff --git a/clients/new-js/packages/chromadb/src/api/sdk.gen.ts b/clients/new-js/packages/chromadb/src/api/sdk.gen.ts index 586c8170a12..daa883d4021 100644 --- a/clients/new-js/packages/chromadb/src/api/sdk.gen.ts +++ b/clients/new-js/packages/chromadb/src/api/sdk.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import type { Options as ClientOptions, TDataShape, Client } from '@hey-api/client-fetch'; -import type { GetUserIdentityData, GetUserIdentityResponse2, GetUserIdentityError, GetCollectionByCrnData, GetCollectionByCrnResponse, GetCollectionByCrnError, HealthcheckData, HealthcheckResponse, HealthcheckError, HeartbeatData, HeartbeatResponse2, HeartbeatError, PreFlightChecksData, PreFlightChecksResponse, PreFlightChecksError, ResetData, ResetResponse, ResetError, CreateTenantData, CreateTenantResponse2, CreateTenantError, GetTenantData, GetTenantResponse2, GetTenantError, UpdateTenantData, UpdateTenantResponse2, UpdateTenantError, ListDatabasesData, ListDatabasesResponse, ListDatabasesError, CreateDatabaseData, CreateDatabaseResponse2, CreateDatabaseError, DeleteDatabaseData, DeleteDatabaseResponse2, DeleteDatabaseError, GetDatabaseData, GetDatabaseResponse, GetDatabaseError, ListCollectionsData, ListCollectionsResponse, ListCollectionsError, CreateCollectionData, CreateCollectionResponse, CreateCollectionError, DeleteCollectionData, DeleteCollectionResponse, DeleteCollectionError, GetCollectionData, GetCollectionResponse, GetCollectionError, UpdateCollectionData, UpdateCollectionResponse2, UpdateCollectionError, CollectionAddData, CollectionAddResponse, CollectionCountData, CollectionCountResponse, CollectionCountError, CollectionDeleteData, CollectionDeleteResponse, CollectionDeleteError, ForkCollectionData, ForkCollectionResponse, ForkCollectionError, CollectionGetData, CollectionGetResponse, CollectionGetError, CollectionQueryData, CollectionQueryResponse, CollectionQueryError, CollectionSearchData, CollectionSearchResponse, CollectionSearchError, CollectionUpdateData, CollectionUpdateResponse, CollectionUpsertData, CollectionUpsertResponse, CollectionUpsertError, CountCollectionsData, CountCollectionsResponse, CountCollectionsError, VersionData, VersionResponse } from './types.gen'; +import type { GetUserIdentityData, GetUserIdentityResponse2, GetUserIdentityError, GetCollectionByCrnData, GetCollectionByCrnResponse, GetCollectionByCrnError, HealthcheckData, HealthcheckResponse, HealthcheckError, HeartbeatData, HeartbeatResponse2, HeartbeatError, PreFlightChecksData, PreFlightChecksResponse, PreFlightChecksError, ResetData, ResetResponse, ResetError, CreateTenantData, CreateTenantResponse2, CreateTenantError, GetTenantData, GetTenantResponse2, GetTenantError, UpdateTenantData, UpdateTenantResponse2, UpdateTenantError, ListDatabasesData, ListDatabasesResponse, ListDatabasesError, CreateDatabaseData, CreateDatabaseResponse2, CreateDatabaseError, DeleteDatabaseData, DeleteDatabaseResponse2, DeleteDatabaseError, GetDatabaseData, GetDatabaseResponse, GetDatabaseError, ListCollectionsData, ListCollectionsResponse, ListCollectionsError, CreateCollectionData, CreateCollectionResponse, CreateCollectionError, DeleteCollectionData, DeleteCollectionResponse, DeleteCollectionError, GetCollectionData, GetCollectionResponse, GetCollectionError, UpdateCollectionData, UpdateCollectionResponse2, UpdateCollectionError, CollectionAddData, CollectionAddResponse, CollectionCountData, CollectionCountResponse, CollectionCountError, CollectionDeleteData, CollectionDeleteResponse, CollectionDeleteError, ForkCollectionData, ForkCollectionResponse, ForkCollectionError, CollectionGetData, CollectionGetResponse, CollectionGetError, CollectionQueryData, CollectionQueryResponse, CollectionQueryError, CollectionSearchData, CollectionSearchResponse, CollectionSearchError, CreateTaskData, CreateTaskResponse2, CreateTaskError, RemoveTaskData, RemoveTaskResponse2, RemoveTaskError, CollectionUpdateData, CollectionUpdateResponse, CollectionUpsertData, CollectionUpsertResponse, CollectionUpsertError, CountCollectionsData, CountCollectionsResponse, CountCollectionsError, VersionData, VersionResponse } from './types.gen'; import { client as _heyApiClient } from './client.gen'; export type Options = ClientOptions & { @@ -313,6 +313,34 @@ export class DefaultService { }); } + /** + * Register a new task for a collection + */ + public static createTask(options: Options) { + return (options.client ?? _heyApiClient).post({ + url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/create', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options?.headers + } + }); + } + + /** + * Remove a task + */ + public static removeTask(options: Options) { + return (options.client ?? _heyApiClient).post({ + url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/delete', + ...options, + headers: { + 'Content-Type': 'application/json', + ...options?.headers + } + }); + } + /** * Updates records in a collection by ID. */ diff --git a/clients/new-js/packages/chromadb/src/api/types.gen.ts b/clients/new-js/packages/chromadb/src/api/types.gen.ts index 89d4ffab5e5..3bcf8799332 100644 --- a/clients/new-js/packages/chromadb/src/api/types.gen.ts +++ b/clients/new-js/packages/chromadb/src/api/types.gen.ts @@ -73,6 +73,18 @@ export type CreateDatabaseResponse = { [key: string]: unknown; }; +export type CreateTaskRequest = { + operator_name: string; + output_collection_name: string; + params?: unknown; + task_name: string; +}; + +export type CreateTaskResponse = { + success: boolean; + task_id: string; +}; + export type CreateTenantPayload = { name: string; }; @@ -273,6 +285,18 @@ export type RawWhereFields = { where_document?: unknown; }; +export type RemoveTaskRequest = { + /** + * Whether to delete the output collection as well + */ + delete_output?: boolean; + task_name: string; +}; + +export type RemoveTaskResponse = { + success: boolean; +}; + export type SearchPayload = { filter?: { query_ids?: Array; @@ -1472,6 +1496,90 @@ export type CollectionSearchResponses = { export type CollectionSearchResponse = CollectionSearchResponses[keyof CollectionSearchResponses]; +export type CreateTaskData = { + body: CreateTaskRequest; + path: { + /** + * Tenant ID + */ + tenant: string; + /** + * Database name + */ + database: string; + /** + * Collection ID + */ + collection_id: string; + }; + query?: never; + url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/create'; +}; + +export type CreateTaskErrors = { + /** + * Unauthorized + */ + 401: ErrorResponse; + /** + * Server error + */ + 500: ErrorResponse; +}; + +export type CreateTaskError = CreateTaskErrors[keyof CreateTaskErrors]; + +export type CreateTaskResponses = { + /** + * Task created successfully + */ + 200: CreateTaskResponse; +}; + +export type CreateTaskResponse2 = CreateTaskResponses[keyof CreateTaskResponses]; + +export type RemoveTaskData = { + body: RemoveTaskRequest; + path: { + /** + * Tenant ID + */ + tenant: string; + /** + * Database name + */ + database: string; + /** + * Collection ID + */ + collection_id: string; + }; + query?: never; + url: '/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/delete'; +}; + +export type RemoveTaskErrors = { + /** + * Unauthorized + */ + 401: ErrorResponse; + /** + * Server error + */ + 500: ErrorResponse; +}; + +export type RemoveTaskError = RemoveTaskErrors[keyof RemoveTaskErrors]; + +export type RemoveTaskResponses = { + /** + * Task removed successfully + */ + 200: RemoveTaskResponse; +}; + +export type RemoveTaskResponse2 = RemoveTaskResponses[keyof RemoveTaskResponses]; + export type CollectionUpdateData = { body: UpdateCollectionRecordsPayload; path: { diff --git a/examples/task_api_example.py b/examples/task_api_example.py new file mode 100644 index 00000000000..f71fd911d35 --- /dev/null +++ b/examples/task_api_example.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Example: Using Chroma's Task API to process collections automatically + +This demonstrates how to register tasks that automatically process +collections as new records are added. +""" + +import chromadb + +# Connect to Chroma server +client = chromadb.HttpClient(host="localhost", port=8000) +# ignore error if collection does not exist +try: + client.delete_collection("my_documents_counts") +except Exception: + pass +# Create or get a collection +collection = client.get_or_create_collection( + name="my_document", metadata={"description": "Sample documents for task processing"} +) + +# Add some sample documents +collection.add( + ids=["doc1", "doc2", "doc3"], + documents=[ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "Python is a popular programming language", + ], + metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}], +) + +print(f"✅ Created collection '{collection.name}' with {collection.count()} documents") + +# Create a task that counts records in the collection +# The 'record_counter' operator processes each record and outputs {"count": N} +success, task_id = collection.create_task( + task_name="count_my_docs", + operator_name="record_counter", # Built-in operator that counts records + output_collection_name="my_documents_counts", # Auto-created + params=None, # No additional parameters needed +) +assert success +if success: + print("✅ Task created successfully!") + print(f" Task ID: {task_id}") + print(" Task name: count_my_docs") + print(f" Input collection: {collection.name}") + print(" Output collection: my_documents_counts") + print(" Operator: record_counter") +else: + print("❌ Failed to create task") + +# The task will now run automatically when: +# 1. New documents are added to 'my_documents' +# 2. The number of new records >= min_records_for_task (default: 100) + +print("\n" + "=" * 60) +print("Task is now registered and will run on new data!") +print("=" * 60) + +# Add more documents to trigger task execution +print("\nAdding more documents...") +collection.add( + ids=["doc4", "doc5"], + documents=["Chroma is a vector database", "Tasks automate data processing"], +) + +print(f"Collection now has {collection.count()} documents") + +# Later, you can remove the task +print("\n" + "=" * 60) +input("Press Enter to remove the task...") + +success = collection.remove_task( + task_name="count_my_docs", delete_output=True # Also delete the output collection +) + +if success: + print("✅ Task removed successfully!") +else: + print("❌ Failed to remove task") diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index 6165b5df944..731c777eba0 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -16,6 +16,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) // CreateTask creates a new task in the database @@ -37,9 +38,8 @@ func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateT return err } if existingTask != nil { - log.Info("CreateTask: task already exists, returning existing") - taskID = existingTask.ID - return nil + log.Error("CreateTask: task already exists", zap.String("task_name", req.Name)) + return common.ErrTaskAlreadyExists } // Generate new task UUID @@ -98,6 +98,19 @@ func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateT return common.ErrCollectionUniqueConstraintViolation } + // Serialize params from protobuf Struct to JSON string for database storage + var paramsJSON string + if req.Params != nil { + paramsBytes, err := req.Params.MarshalJSON() + if err != nil { + log.Error("CreateTask: failed to marshal params", zap.Error(err)) + return err + } + paramsJSON = string(paramsBytes) + } else { + paramsJSON = "{}" + } + now := time.Now() task := &dbmodel.Task{ ID: taskID, @@ -107,7 +120,7 @@ func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateT InputCollectionID: req.InputCollectionId, OutputCollectionName: req.OutputCollectionName, OperatorID: operatorID, - OperatorParams: req.Params, + OperatorParams: paramsJSON, CompletionOffset: 0, LastRun: nil, NextRun: nil, // Will be set to zero initially, scheduled by task scheduler @@ -171,6 +184,16 @@ func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetT // Debug logging log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName)) + // Deserialize params from JSON string to protobuf Struct + var paramsStruct *structpb.Struct + if task.OperatorParams != "" { + paramsStruct = &structpb.Struct{} + if err := paramsStruct.UnmarshalJSON([]byte(task.OperatorParams)); err != nil { + log.Error("GetTaskByName: failed to unmarshal params", zap.Error(err)) + return nil, err + } + } + // Convert task to response response := &coordinatorpb.GetTaskByNameResponse{ TaskId: proto.String(task.ID.String()), @@ -178,9 +201,11 @@ func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetT OperatorName: proto.String(operator.OperatorName), InputCollectionId: proto.String(task.InputCollectionID), OutputCollectionName: proto.String(task.OutputCollectionName), - Params: proto.String(task.OperatorParams), + Params: paramsStruct, CompletionOffset: proto.Int64(task.CompletionOffset), MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)), + TenantId: proto.String(task.TenantID), + DatabaseId: proto.String(task.DatabaseID), } // Add output_collection_id if it's set if task.OutputCollectionID != nil { @@ -211,9 +236,10 @@ func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteT } deleteCollection := &model.DeleteCollection{ - ID: collectionUUID, - TenantID: task.TenantID, - DatabaseName: task.DatabaseID, + ID: collectionUUID, + TenantID: task.TenantID, + // Database name isn't available but also isn't needed since we supplied a collection id + DatabaseName: "", } err = s.SoftDeleteCollection(ctx, deleteCollection) diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go index 9f96f7faf2c..baf6ba1826f 100644 --- a/go/pkg/sysdb/grpc/task_service.go +++ b/go/pkg/sysdb/grpc/task_service.go @@ -3,6 +3,8 @@ package grpc import ( "context" + "github.com/chroma-core/chroma/go/pkg/common" + "github.com/chroma-core/chroma/go/pkg/grpcutils" "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb" "github.com/pingcap/log" "go.uber.org/zap" @@ -14,6 +16,9 @@ func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRe res, err := s.coordinator.CreateTask(ctx, req) if err != nil { log.Error("CreateTask failed", zap.Error(err)) + if err == common.ErrTaskAlreadyExists { + return nil, grpcutils.BuildAlreadyExistsGrpcError(err.Error()) + } return nil, err } @@ -26,6 +31,9 @@ func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskBy res, err := s.coordinator.GetTaskByName(ctx, req) if err != nil { log.Error("GetTaskByName failed", zap.Error(err)) + if err == common.ErrTaskNotFound { + return nil, grpcutils.BuildNotFoundGrpcError(err.Error()) + } return nil, err } diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 1e3ae00fd17..e8dc47794fc 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -5,6 +5,7 @@ option go_package = "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"; import "chromadb/proto/chroma.proto"; import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; import "google/protobuf/timestamp.proto"; message CreateDatabaseRequest { @@ -530,7 +531,7 @@ message CreateTaskRequest { string operator_name = 2; string input_collection_id = 3; string output_collection_name = 4; - string params = 5; + optional google.protobuf.Struct params = 5; string tenant_id = 6; string database = 7; uint64 min_records_for_task = 8; @@ -552,9 +553,11 @@ message GetTaskByNameResponse { optional string input_collection_id = 4; optional string output_collection_name = 5; optional string output_collection_id = 6; - optional string params = 7; + optional google.protobuf.Struct params = 7; optional int64 completion_offset = 8; optional uint64 min_records_for_task = 9; + optional string tenant_id = 10; + optional string database_id = 11; } message DeleteTaskRequest { diff --git a/rust/frontend/src/auth/mod.rs b/rust/frontend/src/auth/mod.rs index 8d596918683..709adecdf39 100644 --- a/rust/frontend/src/auth/mod.rs +++ b/rust/frontend/src/auth/mod.rs @@ -36,6 +36,8 @@ pub enum AuthzAction { Update, Upsert, Search, + CreateTask, + RemoveTask, } impl Display for AuthzAction { @@ -66,6 +68,8 @@ impl Display for AuthzAction { AuthzAction::Update => write!(f, "collection:update"), AuthzAction::Upsert => write!(f, "collection:upsert"), AuthzAction::Search => write!(f, "collection:search"), + AuthzAction::CreateTask => write!(f, "collection:create_task"), + AuthzAction::RemoveTask => write!(f, "collection:remove_task"), } } } diff --git a/rust/frontend/src/config.rs b/rust/frontend/src/config.rs index 4a9772a4150..57b55e11d77 100644 --- a/rust/frontend/src/config.rs +++ b/rust/frontend/src/config.rs @@ -71,6 +71,8 @@ pub struct FrontendConfig { pub tenants_to_migrate_immediately_threshold: Option, #[serde(default = "default_enable_schema")] pub enable_schema: bool, + #[serde(default = "default_min_records_for_task")] + pub min_records_for_task: u64, } impl FrontendConfig { @@ -90,6 +92,7 @@ impl FrontendConfig { tenants_to_migrate_immediately: vec![], tenants_to_migrate_immediately_threshold: None, enable_schema: default_enable_schema(), + min_records_for_task: default_min_records_for_task(), } } } @@ -142,6 +145,10 @@ fn default_enable_schema() -> bool { false } +pub fn default_min_records_for_task() -> u64 { + 100 +} + #[derive(Deserialize, Serialize, Clone, Debug)] pub struct FrontendServerConfig { #[serde(flatten)] diff --git a/rust/frontend/src/impls/service_based_frontend.rs b/rust/frontend/src/impls/service_based_frontend.rs index a467d2c564b..96ca495261c 100644 --- a/rust/frontend/src/impls/service_based_frontend.rs +++ b/rust/frontend/src/impls/service_based_frontend.rs @@ -21,26 +21,27 @@ use chroma_types::{ operator::{Filter, KnnBatch, KnnProjection, Limit, Projection, Scan}, plan::{Count, Get, Knn, Search}, AddCollectionRecordsError, AddCollectionRecordsRequest, AddCollectionRecordsResponse, - Collection, CollectionUuid, CountCollectionsError, CountCollectionsRequest, + AddTaskError, Collection, CollectionUuid, CountCollectionsError, CountCollectionsRequest, CountCollectionsResponse, CountRequest, CountResponse, CreateCollectionError, CreateCollectionRequest, CreateCollectionResponse, CreateDatabaseError, CreateDatabaseRequest, - CreateDatabaseResponse, CreateTenantError, CreateTenantRequest, CreateTenantResponse, - DeleteCollectionError, DeleteCollectionRecordsError, DeleteCollectionRecordsRequest, - DeleteCollectionRecordsResponse, DeleteCollectionRequest, DeleteDatabaseError, - DeleteDatabaseRequest, DeleteDatabaseResponse, ForkCollectionError, ForkCollectionRequest, - ForkCollectionResponse, GetCollectionByCrnError, GetCollectionByCrnRequest, - GetCollectionByCrnResponse, GetCollectionError, GetCollectionRequest, GetCollectionResponse, - GetCollectionsError, GetDatabaseError, GetDatabaseRequest, GetDatabaseResponse, GetRequest, - GetResponse, GetTenantError, GetTenantRequest, GetTenantResponse, HealthCheckResponse, - HeartbeatError, HeartbeatResponse, Include, InternalSchema, KnnIndex, ListCollectionsRequest, - ListCollectionsResponse, ListDatabasesError, ListDatabasesRequest, ListDatabasesResponse, - Operation, OperationRecord, QueryError, QueryRequest, QueryResponse, ResetError, ResetResponse, - SchemaError, SearchRequest, SearchResponse, Segment, SegmentScope, SegmentType, SegmentUuid, - UpdateCollectionError, UpdateCollectionRecordsError, UpdateCollectionRecordsRequest, - UpdateCollectionRecordsResponse, UpdateCollectionRequest, UpdateCollectionResponse, - UpdateTenantError, UpdateTenantRequest, UpdateTenantResponse, UpsertCollectionRecordsError, - UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, VectorIndexConfiguration, - Where, + CreateDatabaseResponse, CreateTaskRequest, CreateTaskResponse, CreateTenantError, + CreateTenantRequest, CreateTenantResponse, DeleteCollectionError, DeleteCollectionRecordsError, + DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, DeleteCollectionRequest, + DeleteDatabaseError, DeleteDatabaseRequest, DeleteDatabaseResponse, ForkCollectionError, + ForkCollectionRequest, ForkCollectionResponse, GetCollectionByCrnError, + GetCollectionByCrnRequest, GetCollectionByCrnResponse, GetCollectionError, + GetCollectionRequest, GetCollectionResponse, GetCollectionsError, GetDatabaseError, + GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse, GetTenantError, + GetTenantRequest, GetTenantResponse, HealthCheckResponse, HeartbeatError, HeartbeatResponse, + Include, InternalSchema, KnnIndex, ListCollectionsRequest, ListCollectionsResponse, + ListDatabasesError, ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord, + QueryError, QueryRequest, QueryResponse, RemoveTaskError, RemoveTaskRequest, + RemoveTaskResponse, ResetError, ResetResponse, SchemaError, SearchRequest, SearchResponse, + Segment, SegmentScope, SegmentType, SegmentUuid, UpdateCollectionError, + UpdateCollectionRecordsError, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, + UpdateCollectionRequest, UpdateCollectionResponse, UpdateTenantError, UpdateTenantRequest, + UpdateTenantResponse, UpsertCollectionRecordsError, UpsertCollectionRecordsRequest, + UpsertCollectionRecordsResponse, VectorIndexConfiguration, Where, }; use opentelemetry::global; use opentelemetry::metrics::Counter; @@ -78,6 +79,7 @@ pub struct ServiceBasedFrontend { default_knn_index: KnnIndex, enable_schema: bool, retries_builder: ExponentialBuilder, + min_records_for_task: u64, } impl ServiceBasedFrontend { @@ -91,6 +93,7 @@ impl ServiceBasedFrontend { max_batch_size: u32, default_knn_index: KnnIndex, enable_schema: bool, + min_records_for_task: u64, ) -> Self { let meter = global::meter("chroma"); let fork_retries_counter = meter.u64_counter("fork_retries").build(); @@ -144,6 +147,7 @@ impl ServiceBasedFrontend { default_knn_index, enable_schema, retries_builder, + min_records_for_task, } } @@ -1833,6 +1837,101 @@ impl ServiceBasedFrontend { res } + pub async fn create_task( + &mut self, + tenant_name: String, + database_name: String, + collection_id: String, + CreateTaskRequest { + task_name, + operator_name, + output_collection_name, + params, + .. + }: CreateTaskRequest, + ) -> Result { + // TODO: Trigger initial task run via heaptender + + // Parse collection_id from path parameter - client-side validation + let input_collection_id = + CollectionUuid(uuid::Uuid::parse_str(&collection_id).map_err(|e| { + AddTaskError::Internal(Box::new(chroma_error::TonicError( + tonic::Status::invalid_argument(format!( + "Client validation error: Invalid collection_id UUID format: {}", + e + )), + ))) + })?); + + let task_id = self + .sysdb_client + .create_task( + task_name.clone(), + operator_name, + input_collection_id, + output_collection_name.clone(), + params, + tenant_name, + database_name, + self.min_records_for_task, + ) + .await + .map_err(|e| match e { + chroma_sysdb::CreateTaskError::AlreadyExists => { + AddTaskError::AlreadyExists(task_name.clone()) + } + chroma_sysdb::CreateTaskError::FailedToCreateTask(s) => { + AddTaskError::Internal(Box::new(chroma_error::TonicError(s))) + } + chroma_sysdb::CreateTaskError::ServerReturnedInvalidData => AddTaskError::Internal( + Box::new(chroma_sysdb::CreateTaskError::ServerReturnedInvalidData), + ), + })?; + + Ok(CreateTaskResponse { + success: true, + task_id: task_id.to_string(), + }) + } + + pub async fn remove_task( + &mut self, + _tenant_id: String, + _database_name: String, + collection_id: String, + RemoveTaskRequest { + task_name, + delete_output, + .. + }: RemoveTaskRequest, + ) -> Result { + // Parse collection_id from path parameter - client-side validation + let collection_uuid = + CollectionUuid(uuid::Uuid::parse_str(&collection_id).map_err(|e| { + RemoveTaskError::Internal(Box::new(chroma_error::TonicError( + tonic::Status::invalid_argument(format!( + "Client validation error: Invalid collection_id UUID format: {}", + e + )), + ))) + })?); + + // Delete task by name - the coordinator handles output collection deletion atomically + self.sysdb_client + .delete_task_by_name(collection_uuid, task_name.clone(), delete_output) + .await + .map_err(|e| match e { + chroma_sysdb::DeleteTaskError::NotFound => { + RemoveTaskError::NotFound(task_name.clone()) + } + chroma_sysdb::DeleteTaskError::FailedToDeleteTask(s) => { + RemoveTaskError::Internal(Box::new(chroma_error::TonicError(s))) + } + })?; + + Ok(RemoveTaskResponse { success: true }) + } + pub async fn healthcheck(&self) -> HealthCheckResponse { HealthCheckResponse { is_executor_ready: self.executor.is_ready().await, @@ -1897,6 +1996,7 @@ impl Configurable<(FrontendConfig, System)> for ServiceBasedFrontend { max_batch_size, config.default_knn_index, config.enable_schema, + config.min_records_for_task, )) } } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index ca0330b28cb..1d7456548f2 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -16,16 +16,16 @@ use chroma_types::{ AddCollectionRecordsResponse, ChecklistResponse, Collection, CollectionConfiguration, CollectionMetadataUpdate, CollectionUuid, CountCollectionsRequest, CountCollectionsResponse, CountRequest, CountResponse, CreateCollectionRequest, CreateDatabaseRequest, - CreateDatabaseResponse, CreateTenantRequest, CreateTenantResponse, - DeleteCollectionRecordsResponse, DeleteDatabaseRequest, DeleteDatabaseResponse, - GetCollectionByCrnRequest, GetCollectionRequest, GetDatabaseRequest, GetDatabaseResponse, - GetRequest, GetResponse, GetTenantRequest, GetTenantResponse, GetUserIdentityResponse, - HeartbeatResponse, IncludeList, InternalCollectionConfiguration, + CreateDatabaseResponse, CreateTaskRequest, CreateTaskResponse, CreateTenantRequest, + CreateTenantResponse, DeleteCollectionRecordsResponse, DeleteDatabaseRequest, + DeleteDatabaseResponse, GetCollectionByCrnRequest, GetCollectionRequest, GetDatabaseRequest, + GetDatabaseResponse, GetRequest, GetResponse, GetTenantRequest, GetTenantResponse, + GetUserIdentityResponse, HeartbeatResponse, IncludeList, InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListCollectionsRequest, ListCollectionsResponse, ListDatabasesRequest, ListDatabasesResponse, Metadata, QueryRequest, QueryResponse, - SearchRequest, SearchResponse, UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, - UpdateCollectionResponse, UpdateMetadata, UpdateTenantRequest, UpdateTenantResponse, - UpsertCollectionRecordsResponse, + RemoveTaskRequest, RemoveTaskResponse, SearchRequest, SearchResponse, + UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, UpdateCollectionResponse, + UpdateMetadata, UpdateTenantRequest, UpdateTenantResponse, UpsertCollectionRecordsResponse, }; use chroma_types::{ForkCollectionResponse, RawWhereFields}; use mdac::{Rule, Scorecard, ScorecardTicket}; @@ -150,6 +150,8 @@ pub struct Metrics { collection_get: Counter, collection_query: Counter, collection_search: Counter, + create_task: Counter, + remove_task: Counter, } impl Metrics { @@ -184,6 +186,8 @@ impl Metrics { collection_get: meter.u64_counter("collection_get").build(), collection_query: meter.u64_counter("collection_query").build(), collection_search: meter.u64_counter("collection_search").build(), + create_task: meter.u64_counter("create_task").build(), + remove_task: meter.u64_counter("remove_task").build(), } } } @@ -325,6 +329,14 @@ impl FrontendServer { "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/search", post(collection_search), ) + .route( + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/create", + post(create_task), + ) + .route( + "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/delete", + post(remove_task), + ) .merge(docs_router) .with_state(self) .layer(DefaultBodyLimit::max(max_payload_size_bytes)) @@ -2232,6 +2244,102 @@ async fn collection_search( Ok(Json(res)) } +/// Register a new task for a collection +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/create", + request_body = CreateTaskRequest, + responses( + (status = 200, description = "Task created successfully", body = CreateTaskResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID") + ) +)] +async fn create_task( + headers: HeaderMap, + Path((tenant, database, collection_id)): Path<(String, String, String)>, + State(mut server): State, + TracedJson(request): TracedJson, +) -> Result, ServerError> { + server.metrics.create_task.add(1, &[]); + server + .authenticate_and_authorize( + &headers, + AuthzAction::CreateTask, + AuthzResource { + tenant: Some(tenant.clone()), + database: Some(database.clone()), + collection: None, + }, + ) + .await?; + + let _guard = server.scorecard_request(&[ + "op:create_task", + format!("tenant:{}", tenant).as_str(), + format!("database:{}", database).as_str(), + ])?; + + let res = server + .frontend + .create_task(tenant, database, collection_id, request) + .await?; + Ok(Json(res)) +} + +/// Remove a task +#[utoipa::path( + post, + path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}/tasks/delete", + request_body = RemoveTaskRequest, + responses( + (status = 200, description = "Task removed successfully", body = RemoveTaskResponse), + (status = 401, description = "Unauthorized", body = ErrorResponse), + (status = 500, description = "Server error", body = ErrorResponse) + ), + params( + ("tenant" = String, Path, description = "Tenant ID"), + ("database" = String, Path, description = "Database name"), + ("collection_id" = String, Path, description = "Collection ID") + ) +)] +async fn remove_task( + headers: HeaderMap, + Path((tenant, database_name, collection_id)): Path<(String, String, String)>, + State(mut server): State, + TracedJson(request): TracedJson, +) -> Result, ServerError> { + server.metrics.remove_task.add(1, &[]); + server + .authenticate_and_authorize( + &headers, + AuthzAction::RemoveTask, + AuthzResource { + tenant: Some(tenant.clone()), + database: Some(database_name.clone()), + collection: None, + }, + ) + .await?; + + let _guard = server.scorecard_request(&[ + "op:remove_task", + format!("tenant:{}", tenant).as_str(), + format!("database:{}", database_name).as_str(), + ])?; + + let res = server + .frontend + .remove_task(tenant, database_name, collection_id, request) + .await?; + Ok(Json(res)) +} + async fn v1_deprecation_notice() -> Response { let err_response = ErrorResponse::new( "Unimplemented".to_string(), @@ -2290,6 +2398,8 @@ impl Modify for ChromaTokenSecurityAddon { collection_get, collection_query, collection_search, + create_task, + remove_task, ), // Apply our new security scheme here modifiers(&ChromaTokenSecurityAddon) diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 50d9c611129..99b51e063e2 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -3,6 +3,7 @@ use chroma_cache::FoyerCacheConfig; use chroma_cli::chroma_cli; use chroma_config::{registry::Registry, Configurable}; use chroma_frontend::{ + config::default_min_records_for_task, executor::config::{ExecutorConfig, LocalExecutorConfig}, get_collection_with_segments_provider::{ CacheInvalidationRetryConfig, CollectionsWithSegmentsProviderConfig, @@ -126,6 +127,7 @@ impl Bindings { tenants_to_migrate_immediately: vec![], tenants_to_migrate_immediately_threshold: None, enable_schema, + min_records_for_task: default_min_records_for_task(), }; let frontend = runtime.block_on(async { diff --git a/rust/sysdb/Cargo.toml b/rust/sysdb/Cargo.toml index 51fbe8a977a..1c1267d39e6 100644 --- a/rust/sysdb/Cargo.toml +++ b/rust/sysdb/Cargo.toml @@ -23,6 +23,7 @@ sea-query = { workspace = true } sea-query-binder = { workspace = true, features = ["sqlx-sqlite"] } chrono = { workspace = true } prost = { workspace = true } +prost-types = { workspace = true } derivative = "2.2.0" chroma-config = { workspace = true } diff --git a/rust/sysdb/src/sqlite.rs b/rust/sysdb/src/sqlite.rs index e5ee48ae22d..9b23982ab53 100644 --- a/rust/sysdb/src/sqlite.rs +++ b/rust/sysdb/src/sqlite.rs @@ -612,6 +612,57 @@ impl SqliteSysDb { Ok(ResetResponse {}) } + #[allow(clippy::too_many_arguments)] + pub(crate) async fn create_task( + &self, + _name: String, + _operator_id: String, + _input_collection_id: chroma_types::CollectionUuid, + _output_collection_name: String, + _params: serde_json::Value, + _tenant_id: String, + _database_id: String, + _min_records_for_task: u64, + ) -> Result { + // TODO: Implement this when task support is added to SqliteSysDb + Err(crate::CreateTaskError::FailedToCreateTask( + tonic::Status::unimplemented("Task operations not yet implemented in SqliteSysDb"), + )) + } + + pub(crate) async fn get_task_by_name( + &self, + _input_collection_id: chroma_types::CollectionUuid, + _task_name: String, + ) -> Result { + // TODO: Implement this when task support is added to SqliteSysDb + Err(crate::GetTaskError::FailedToGetTask( + tonic::Status::unimplemented("Task operations not yet implemented in SqliteSysDb"), + )) + } + + pub(crate) async fn soft_delete_task( + &self, + _task_id: chroma_types::TaskUuid, + ) -> Result<(), crate::DeleteTaskError> { + // TODO: Implement this when task support is added to SqliteSysDb + Err(crate::DeleteTaskError::FailedToDeleteTask( + tonic::Status::unimplemented("Task operations not yet implemented in SqliteSysDb"), + )) + } + + pub(crate) async fn delete_task_by_name( + &self, + _input_collection_id: chroma_types::CollectionUuid, + _task_name: String, + _delete_output: bool, + ) -> Result<(), crate::DeleteTaskError> { + // TODO: Implement this when task support is added to SqliteSysDb + Err(crate::DeleteTaskError::FailedToDeleteTask( + tonic::Status::unimplemented("Task operations not yet implemented in SqliteSysDb"), + )) + } + #[allow(clippy::too_many_arguments)] async fn get_collections_with_conn<'a, C>( &self, diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index d6ba475dfd9..ee2ef67141d 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -26,6 +26,7 @@ use chroma_types::{ ForkCollectionError, InternalSchema, SchemaError, Segment, SegmentConversionError, SegmentScope, Tenant, }; +use prost_types; use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; @@ -38,6 +39,61 @@ use uuid::{Error, Uuid}; pub const VERSION_FILE_S3_PREFIX: &str = "sysdb/version_files/"; +// Helper function to convert serde_json::Value to prost_types::Value +fn json_to_prost_value(json: serde_json::Value) -> prost_types::Value { + use prost_types::value::Kind; + let kind = match json { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(b) => Kind::BoolValue(b), + serde_json::Value::Number(n) => { + if let Some(f) = n.as_f64() { + Kind::NumberValue(f) + } else { + Kind::NullValue(0) + } + } + serde_json::Value::String(s) => Kind::StringValue(s), + serde_json::Value::Array(arr) => Kind::ListValue(prost_types::ListValue { + values: arr.into_iter().map(json_to_prost_value).collect(), + }), + serde_json::Value::Object(map) => Kind::StructValue(prost_types::Struct { + fields: map + .into_iter() + .map(|(k, v)| (k, json_to_prost_value(v))) + .collect(), + }), + }; + prost_types::Value { kind: Some(kind) } +} + +// Helper function to convert prost_types::Value to serde_json::Value +fn prost_value_to_json(value: prost_types::Value) -> serde_json::Value { + use prost_types::value::Kind; + match value.kind { + Some(Kind::NullValue(_)) => serde_json::Value::Null, + Some(Kind::BoolValue(b)) => serde_json::Value::Bool(b), + Some(Kind::NumberValue(n)) => serde_json::Number::from_f64(n) + .map(serde_json::Value::Number) + .unwrap_or(serde_json::Value::Null), + Some(Kind::StringValue(s)) => serde_json::Value::String(s), + Some(Kind::ListValue(list)) => { + serde_json::Value::Array(list.values.into_iter().map(prost_value_to_json).collect()) + } + Some(Kind::StructValue(s)) => prost_struct_to_json(s), + None => serde_json::Value::Null, + } +} + +// Helper function to convert prost_types::Struct to serde_json::Value +fn prost_struct_to_json(s: prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object( + s.fields + .into_iter() + .map(|(k, v)| (k, prost_value_to_json(v))) + .collect(), + ) +} + #[derive(Debug, Clone)] pub enum SysDb { Grpc(GrpcSysDb), @@ -1586,6 +1642,179 @@ impl GrpcSysDb { .map_err(|e| TonicError(e).boxed())?; Ok(ResetResponse {}) } + + #[allow(clippy::too_many_arguments)] + pub async fn create_task( + &mut self, + name: String, + operator_name: String, + input_collection_id: chroma_types::CollectionUuid, + output_collection_name: String, + params: serde_json::Value, + tenant_name: String, + database_name: String, + min_records_for_task: u64, + ) -> Result { + // Convert serde_json::Value to prost_types::Struct for gRPC + let params_struct = match params { + serde_json::Value::Object(map) => Some(prost_types::Struct { + fields: map + .into_iter() + .map(|(k, v)| (k, json_to_prost_value(v))) + .collect(), + }), + _ => None, // Non-object params omitted from proto + }; + + let req = chroma_proto::CreateTaskRequest { + name: name.clone(), + operator_name: operator_name.clone(), + input_collection_id: input_collection_id.to_string(), + output_collection_name: output_collection_name.clone(), + params: params_struct, + tenant_id: tenant_name.clone(), + database: database_name.clone(), + min_records_for_task, + }; + + let response = self.client.create_task(req).await?.into_inner(); + + // Parse the returned task_id - this should always succeed since the server generated it + // If this fails, it indicates a serious server bug or protocol corruption + let task_id = chroma_types::TaskUuid( + uuid::Uuid::parse_str(&response.task_id).map_err(|e| { + tracing::error!( + task_id = %response.task_id, + error = %e, + "Server returned invalid task_id UUID - task was created but response is corrupt" + ); + CreateTaskError::ServerReturnedInvalidData + })?, + ); + + Ok(task_id) + } + + pub async fn get_task_by_name( + &mut self, + input_collection_id: chroma_types::CollectionUuid, + task_name: String, + ) -> Result { + let req = chroma_proto::GetTaskByNameRequest { + input_collection_id: input_collection_id.to_string(), + task_name: task_name.clone(), + }; + + let response = match self.client.get_task_by_name(req).await { + Ok(resp) => resp, + Err(status) => { + if status.code() == tonic::Code::NotFound { + return Err(GetTaskError::NotFound); + } + return Err(GetTaskError::FailedToGetTask(status)); + } + }; + let response = response.into_inner(); + + // If response has no task_id, task was not found + if response.task_id.is_none() { + return Err(GetTaskError::NotFound); + } + + // Parse the response and construct Task + let task_id_str = response.task_id.unwrap(); + let task_id = chroma_types::TaskUuid(uuid::Uuid::parse_str(&task_id_str).map_err(|e| { + tracing::error!( + task_id = %task_id_str, + error = %e, + "Server returned invalid task_id UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?); + + let operator_id = response.operator_name.ok_or_else(|| { + GetTaskError::FailedToGetTask(tonic::Status::internal( + "Missing operator_name in response", + )) + })?; + + let input_collection_id_str = response + .input_collection_id + .unwrap_or_else(|| input_collection_id.to_string()); + let parsed_input_collection_id = chroma_types::CollectionUuid( + uuid::Uuid::parse_str(&input_collection_id_str).map_err(|e| { + tracing::error!( + input_collection_id = %input_collection_id_str, + error = %e, + "Server returned invalid input_collection_id UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?, + ); + + // Convert params from Struct to JSON string + let params_str = response.params.map(|s| { + let json_value = prost_struct_to_json(s); + serde_json::to_string(&json_value).unwrap_or_else(|_| "{}".to_string()) + }); + + Ok(chroma_types::Task { + id: task_id, + name: response.name.unwrap_or(task_name), + operator_id, + input_collection_id: parsed_input_collection_id, + output_collection_name: response.output_collection_name.unwrap_or_default(), + output_collection_id: Some(response.output_collection_id.unwrap_or_default()), + params: params_str, + tenant_id: response.tenant_id.unwrap_or_default(), + database_id: response.database_id.unwrap_or_default(), + last_run: None, + next_run: None, + completion_offset: response.completion_offset.unwrap_or(0) as u64, + min_records_for_task: response.min_records_for_task.unwrap_or(100), + is_deleted: false, + created_at: std::time::SystemTime::now(), + updated_at: std::time::SystemTime::now(), + }) + } + + pub async fn soft_delete_task( + &mut self, + _task_id: chroma_types::TaskUuid, + ) -> Result<(), DeleteTaskError> { + // Note: The gRPC DeleteTask API requires tenant_id, database_id, and task_name. + // We cannot implement this method with just a task_id. + // Callers should use delete_task_by_name() instead, which has all required parameters. + Err(DeleteTaskError::FailedToDeleteTask( + tonic::Status::unimplemented( + "soft_delete_task by ID not supported - use delete_task_by_name instead", + ), + )) + } + + pub async fn delete_task_by_name( + &mut self, + input_collection_id: chroma_types::CollectionUuid, + task_name: String, + delete_output: bool, + ) -> Result<(), DeleteTaskError> { + let req = chroma_proto::DeleteTaskRequest { + input_collection_id: input_collection_id.to_string(), + task_name, + delete_output, + }; + + match self.client.delete_task(req).await { + Ok(_) => Ok(()), + Err(status) => { + if status.code() == tonic::Code::NotFound { + Err(DeleteTaskError::NotFound) + } else { + Err(DeleteTaskError::FailedToDeleteTask(status)) + } + } + } + } } #[derive(Error, Debug)] @@ -1673,6 +1902,166 @@ impl ChromaError for DeleteCollectionVersionError { } } +////////////////////////// Task Operations ////////////////////////// + +impl SysDb { + #[allow(clippy::too_many_arguments)] + pub async fn create_task( + &mut self, + name: String, + operator_name: String, + input_collection_id: chroma_types::CollectionUuid, + output_collection_name: String, + params: serde_json::Value, + tenant_name: String, + database_name: String, + min_records_for_task: u64, + ) -> Result { + match self { + SysDb::Grpc(grpc) => { + grpc.create_task( + name, + operator_name, + input_collection_id, + output_collection_name, + params, + tenant_name, + database_name, + min_records_for_task, + ) + .await + } + SysDb::Sqlite(sqlite) => { + sqlite + .create_task( + name, + operator_name, + input_collection_id, + output_collection_name, + params, + tenant_name, + database_name, + min_records_for_task, + ) + .await + } + SysDb::Test(_) => { + todo!() + } + } + } + + pub async fn get_task_by_name( + &mut self, + input_collection_id: chroma_types::CollectionUuid, + task_name: String, + ) -> Result { + match self { + SysDb::Grpc(grpc) => grpc.get_task_by_name(input_collection_id, task_name).await, + SysDb::Sqlite(sqlite) => { + sqlite + .get_task_by_name(input_collection_id, task_name) + .await + } + SysDb::Test(_) => { + todo!() + } + } + } + + pub async fn soft_delete_task( + &mut self, + task_id: chroma_types::TaskUuid, + ) -> Result<(), DeleteTaskError> { + match self { + SysDb::Grpc(grpc) => grpc.soft_delete_task(task_id).await, + SysDb::Sqlite(sqlite) => sqlite.soft_delete_task(task_id).await, + SysDb::Test(_) => { + todo!() + } + } + } + + pub async fn delete_task_by_name( + &mut self, + input_collection_id: chroma_types::CollectionUuid, + task_name: String, + delete_output: bool, + ) -> Result<(), DeleteTaskError> { + match self { + SysDb::Grpc(grpc) => { + grpc.delete_task_by_name(input_collection_id, task_name, delete_output) + .await + } + SysDb::Sqlite(sqlite) => { + sqlite + .delete_task_by_name(input_collection_id, task_name, delete_output) + .await + } + SysDb::Test(_) => { + todo!() + } + } + } +} + +#[derive(Error, Debug)] +pub enum CreateTaskError { + #[error("Task already exists")] + AlreadyExists, + #[error("Failed to create task: {0}")] + FailedToCreateTask(#[from] tonic::Status), + #[error("Server returned invalid data - task was created but response is corrupt")] + ServerReturnedInvalidData, +} + +impl ChromaError for CreateTaskError { + fn code(&self) -> ErrorCodes { + match self { + CreateTaskError::AlreadyExists => ErrorCodes::AlreadyExists, + CreateTaskError::FailedToCreateTask(e) => e.code().into(), + CreateTaskError::ServerReturnedInvalidData => ErrorCodes::Internal, + } + } +} + +#[derive(Error, Debug)] +pub enum GetTaskError { + #[error("Task not found")] + NotFound, + #[error("Failed to get task: {0}")] + FailedToGetTask(tonic::Status), + #[error("Server returned invalid data")] + ServerReturnedInvalidData, +} + +impl ChromaError for GetTaskError { + fn code(&self) -> ErrorCodes { + match self { + GetTaskError::NotFound => ErrorCodes::NotFound, + GetTaskError::FailedToGetTask(e) => e.code().into(), + GetTaskError::ServerReturnedInvalidData => ErrorCodes::Internal, + } + } +} + +#[derive(Error, Debug)] +pub enum DeleteTaskError { + #[error("Task not found")] + NotFound, + #[error("Failed to delete task: {0}")] + FailedToDeleteTask(#[from] tonic::Status), +} + +impl ChromaError for DeleteTaskError { + fn code(&self) -> ErrorCodes { + match self { + DeleteTaskError::NotFound => ErrorCodes::NotFound, + DeleteTaskError::FailedToDeleteTask(e) => e.code().into(), + } + } +} + #[cfg(test)] mod tests { use tonic::Status; diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index 074de318091..4e114cfc73a 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -2032,6 +2032,119 @@ impl ChromaError for ExecutorError { } } +////////////////////////// Task Operations ////////////////////////// + +#[non_exhaustive] +#[derive(Clone, Debug, Deserialize, Serialize, Validate, ToSchema)] +pub struct CreateTaskRequest { + #[validate(length(min = 1))] + pub task_name: String, + pub operator_name: String, + pub output_collection_name: String, + #[serde(default = "default_empty_json_object")] + pub params: serde_json::Value, +} + +fn default_empty_json_object() -> serde_json::Value { + serde_json::json!({}) +} + +impl CreateTaskRequest { + pub fn try_new( + task_name: String, + operator_name: String, + output_collection_name: String, + params: serde_json::Value, + ) -> Result { + let request = Self { + task_name, + operator_name, + output_collection_name, + params, + }; + request.validate().map_err(ChromaValidationError::from)?; + Ok(request) + } +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct CreateTaskResponse { + pub success: bool, + pub task_id: String, +} + +#[derive(Error, Debug)] +pub enum AddTaskError { + #[error("Task with name [{0}] already exists")] + AlreadyExists(String), + #[error("Input collection [{0}] does not exist")] + InputCollectionNotFound(String), + #[error("Output collection [{0}] already exists")] + OutputCollectionExists(String), + #[error(transparent)] + Validation(#[from] ChromaValidationError), + #[error(transparent)] + Internal(#[from] Box), +} + +impl ChromaError for AddTaskError { + fn code(&self) -> ErrorCodes { + match self { + AddTaskError::AlreadyExists(_) => ErrorCodes::AlreadyExists, + AddTaskError::InputCollectionNotFound(_) => ErrorCodes::NotFound, + AddTaskError::OutputCollectionExists(_) => ErrorCodes::AlreadyExists, + AddTaskError::Validation(err) => err.code(), + AddTaskError::Internal(err) => err.code(), + } + } +} + +#[non_exhaustive] +#[derive(Clone, Debug, Deserialize, Validate, Serialize, ToSchema)] +pub struct RemoveTaskRequest { + #[validate(length(min = 1))] + pub task_name: String, + /// Whether to delete the output collection as well + #[serde(default)] + pub delete_output: bool, +} + +impl RemoveTaskRequest { + pub fn try_new(task_name: String, delete_output: bool) -> Result { + let request = Self { + task_name, + delete_output, + }; + request.validate().map_err(ChromaValidationError::from)?; + Ok(request) + } +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct RemoveTaskResponse { + pub success: bool, +} + +#[derive(Error, Debug)] +pub enum RemoveTaskError { + #[error("Task with name [{0}] does not exist")] + NotFound(String), + #[error(transparent)] + Validation(#[from] ChromaValidationError), + #[error(transparent)] + Internal(#[from] Box), +} + +impl ChromaError for RemoveTaskError { + fn code(&self) -> ErrorCodes { + match self { + RemoveTaskError::NotFound(_) => ErrorCodes::NotFound, + RemoveTaskError::Validation(err) => err.code(), + RemoveTaskError::Internal(err) => err.code(), + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index d9721f18339..0c5b3c953cd 100644 --- a/rust/types/src/task.rs +++ b/rust/types/src/task.rs @@ -53,7 +53,7 @@ pub struct Task { pub id: TaskUuid, /// Human-readable name for the task instance pub name: String, - /// Identifier for the operator/built-in definition this task uses + /// Name of the operator/built-in definition this task uses (despite field name, this is a name not a UUID) pub operator_id: String, /// Source collection that triggers the task pub input_collection_id: CollectionUuid, @@ -63,9 +63,9 @@ pub struct Task { pub output_collection_id: Option, /// Optional JSON parameters for the operator pub params: Option, - /// Tenant this task belongs to + /// Tenant name this task belongs to (despite field name, this is a name not a UUID) pub tenant_id: String, - /// Database this task belongs to + /// Database name this task belongs to (despite field name, this is a name not a UUID) pub database_id: String, /// Timestamp of the last successful task run #[serde(skip, default)]