Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/_python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ jobs:
- "chromadb/test/distributed/test_sanity.py"
- "chromadb/test/distributed/test_log_backpressure.py"
- "chromadb/test/distributed/test_repair_collection_log_offset.py"
- "chromadb/test/distributed/test_task_api.py"
include:
- test-glob: "chromadb/test/property/test_add.py"
parallelized: false
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 51 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Sequence, Optional, List
from typing import Sequence, Optional, List, Dict, Any
from uuid import UUID

from overrides import override
Expand Down Expand Up @@ -775,3 +775,53 @@ def _delete(
database: str = DEFAULT_DATABASE,
) -> None:
pass

@abstractmethod
def create_task(
self,
task_name: str,
operator_name: str,
input_collection_id: UUID,
output_collection_name: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
"""Create a recurring task on a collection.

Args:
task_name: Unique name for this task instance
operator_name: Built-in operator name (e.g., 'record_counter')
input_collection_id: Source collection that triggers the task
output_collection_name: Target collection where task output is stored
params: Optional dictionary with operator-specific parameters
tenant: The tenant name
database: The database name

Returns:
tuple: (success: bool, task_id: str)
"""
pass

@abstractmethod
def remove_task(
self,
task_name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Delete a task and prevent any further runs.

Args:
task_name: Name of the task to remove
input_collection_id: Id of the input collection the task is registered on
delete_output: Whether to also delete the output collection
tenant: The tenant name
database: The database name

Returns:
bool: True if successful
"""
pass
46 changes: 46 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,49 @@ def get_max_batch_size(self) -> int:
pre_flight_checks = self.get_pre_flight_checks()
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
return max_batch_size

@trace_method("FastAPI.create_task", OpenTelemetryGranularity.ALL)
@override
def create_task(
self,
task_name: str,
operator_name: str,
input_collection_id: UUID,
output_collection_name: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
"""Register a recurring task on a collection."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/create",
json={
"task_name": task_name,
"operator_name": operator_name,
"output_collection_name": output_collection_name,
"params": params,
},
)
return cast(bool, resp_json["success"]), cast(str, resp_json["task_id"])

@trace_method("FastAPI.remove_task", OpenTelemetryGranularity.ALL)
@override
def remove_task(
self,
task_name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Delete a task and prevent any further runs."""
resp_json = self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/delete",
json={
"task_name": task_name,
"delete_output": delete_output,
},
)
return cast(bool, resp_json["success"])
73 changes: 67 additions & 6 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Union, List, cast
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any

from chromadb.api.models.CollectionCommon import CollectionCommon
from chromadb.api.types import (
Expand Down Expand Up @@ -327,29 +327,29 @@ def search(
from chromadb.execution.expression import (
Search, Key, K, Knn, Val
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these from your editor? Or did you run the python formatter? I'm always suspicious of whitespace changes that would imply the formatter has changed or was not run.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

editor, left them in because i thought it was good to get rid of trailing whitespace

# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
search = (Search()
.where((K("category") == "science") & (K("score") > 0.5))
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
.limit(10, offset=0)
.select(K.DOCUMENT, K.SCORE, "title"))

# Direct construction
from chromadb.execution.expression import (
Search, Eq, And, Gt, Knn, Limit, Select, Key
)

search = Search(
where=And([Eq("category", "science"), Gt("score", 0.5)]),
rank=Knn(query=[0.1, 0.2, 0.3]),
limit=Limit(offset=0, limit=10),
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
)

# Single search
result = collection.search(search)

# Multiple searches at once
searches = [
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
Expand Down Expand Up @@ -490,3 +490,64 @@ def delete(
tenant=self.tenant,
database=self.database,
)

def create_task(
self,
task_name: str,
operator_name: str,
output_collection_name: str,
params: Optional[Dict[str, Any]] = None,
) -> tuple[bool, str]:
"""Create a recurring task that processes this collection.

Args:
task_name: Unique name for this task instance
operator_name: Built-in operator name (e.g., "record_counter")
output_collection_name: Name of the collection where task output will be stored
params: Optional dictionary with operator-specific parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this blob just passed in to the operator as e.g. a JSON value? How does the operator get these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are stored in the task definition as a JSON string that a TaskRunner receives and passes into the operator it executes for that Task.


Returns:
tuple: (success: bool, task_id: str)

Example:
>>> success, task_id = collection.create_task(
... task_name="count_docs",
... operator_name="record_counter",
... output_collection_name="doc_counts",
... params={"threshold": 100}
... )
"""
return self._client.create_task(
task_name=task_name,
operator_name=operator_name,
input_collection_id=self.id,
output_collection_name=output_collection_name,
params=params,
tenant=self.tenant,
database=self.database,
)

def remove_task(
self,
task_name: str,
delete_output: bool = False,
) -> bool:
"""Delete a task and prevent any further runs.

Args:
task_name: Name of the task to remove
delete_output: Whether to also delete the output collection. Defaults to False.

Returns:
bool: True if successful

Example:
>>> success = collection.remove_task("count_docs", delete_output=True)
"""
return self._client.remove_task(
task_name=task_name,
input_collection_id=self.id,
delete_output=delete_output,
tenant=self.tenant,
database=self.database,
)
34 changes: 33 additions & 1 deletion chromadb/api/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import chromadb_rust_bindings


from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Dict, Any
from overrides import override
from uuid import UUID
import json
Expand Down Expand Up @@ -587,6 +587,38 @@ def get_settings(self) -> Settings:
def get_max_batch_size(self) -> int:
return self.bindings.get_max_batch_size()

@override
def create_task(
self,
task_name: str,
operator_name: str,
input_collection_id: UUID,
output_collection_name: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
"""Tasks are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support task operations."
)

@override
def remove_task(
self,
task_name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Tasks are not supported in the Rust bindings (local embedded mode)."""
raise NotImplementedError(
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
"The Rust bindings (embedded mode) do not support task operations."
)

# TODO: Remove this if it's not planned to be used
@override
def get_user_identity(self) -> UserIdentity:
Expand Down
33 changes: 33 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
Generator,
List,
Any,
Dict,
Callable,
TypeVar,
)
Expand Down Expand Up @@ -906,6 +907,38 @@ def get_settings(self) -> Settings:
def get_max_batch_size(self) -> int:
return self._producer.max_batch_size

@override
def create_task(
self,
task_name: str,
operator_name: str,
input_collection_id: UUID,
output_collection_name: str,
params: Optional[Dict[str, Any]] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
"""Tasks are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
"The Segment API (embedded mode) does not support task operations."
)

@override
def remove_task(
self,
task_name: str,
input_collection_id: UUID,
delete_output: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
"""Tasks are not supported in the Segment API (local embedded mode)."""
raise NotImplementedError(
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
"The Segment API (embedded mode) does not support task operations."
)

# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
Expand Down
Loading
Loading