2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

33 changes: 14 additions & 19 deletions chromadb/config.py
Expand Up @@ -12,16 +12,8 @@
from typing_extensions import Literal
import platform

in_pydantic_v2 = False
try:
from pydantic import BaseSettings
except ImportError:
in_pydantic_v2 = True
from pydantic.v1 import BaseSettings
from pydantic.v1 import validator

if not in_pydantic_v2:
from pydantic import validator # type: ignore # noqa
from pydantic_settings import BaseSettings
from pydantic import field_validator

# The thin client will have a flag to control which implementations to use
is_thin_client = False
Expand Down Expand Up @@ -117,7 +109,13 @@ class RoutingMode(Enum):
ID = "id"


class Settings(BaseSettings): # type: ignore
class Settings(BaseSettings):
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"extra": "ignore", # Ignore extra environment variables not defined in model
}

# ==============
# Generic config
# ==============
Expand All @@ -127,7 +125,8 @@ class Settings(BaseSettings): # type: ignore
# Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" or "chromadb.api.rust.RustBindingsAPI"
chroma_api_impl: str = "chromadb.api.rust.RustBindingsAPI"

@validator("chroma_server_nofile", pre=True, always=True, allow_reuse=True)
@field_validator("chroma_server_nofile", mode="before")
@classmethod
def empty_str_to_none(cls, v: str) -> Optional[str]:
if type(v) is str and v.strip() == "":
return None
Expand Down Expand Up @@ -256,7 +255,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
worker_memberlist_name: str = "query-service-memberlist"

chroma_coordinator_host = "localhost"
chroma_coordinator_host: str = "localhost"
# TODO this is the sysdb port. Should probably rename it.
chroma_server_grpc_port: Optional[int] = None
chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB"
Expand All @@ -270,8 +269,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor"
chroma_query_replication_factor: int = 2

chroma_logservice_host = "localhost"
chroma_logservice_port = 50052
chroma_logservice_host: str = "localhost"
chroma_logservice_port: int = 50052

chroma_quota_provider_impl: Optional[str] = None
chroma_rate_limiting_provider_impl: Optional[str] = None
Expand Down Expand Up @@ -323,10 +322,6 @@ def __getitem__(self, key: str) -> Any:
raise ValueError(LEGACY_ERROR)
return val

class Config:
env_file = ".env"
env_file_encoding = "utf-8"


T = TypeVar("T", bound="Component")

Expand Down
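
Taken together, the config.py changes above replace the Pydantic 1.x settings style (inner Config class, @validator) with the Pydantic 2.x style: BaseSettings now comes from the separate pydantic-settings package, configuration moves into a model_config dict, and validators use @field_validator stacked on @classmethod. A minimal sketch of the pattern, with a hypothetical example_port field standing in for the real settings:

    from typing import Any, Optional

    from pydantic import field_validator
    from pydantic_settings import BaseSettings


    class ExampleSettings(BaseSettings):
        # Replaces the Pydantic 1.x inner `class Config` block.
        model_config = {
            "env_file": ".env",
            "env_file_encoding": "utf-8",
            "extra": "ignore",  # ignore environment variables not declared on the model
        }

        example_port: Optional[int] = None  # hypothetical field, not from chromadb.config

        # Replaces @validator("...", pre=True, always=True); mode="before" runs
        # before type coercion, and Pydantic 2.x expects the classmethod stacking.
        @field_validator("example_port", mode="before")
        @classmethod
        def empty_str_to_none(cls, v: Any) -> Optional[Any]:
            if isinstance(v, str) and v.strip() == "":
                return None
            return v

Instantiating ExampleSettings() would then read EXAMPLE_PORT from the environment or a .env file, coercing an empty string to None before integer parsing runs.
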
48 changes: 20 additions & 28 deletions chromadb/server/fastapi/__init__.py
Expand Up @@ -152,14 +152,6 @@ async def check_http_version_middleware(
D = TypeVar("D", bound=BaseModel, contravariant=True)


def validate_model(model: Type[D], data: Any) -> D: # type: ignore
"""Used for backward compatibility with Pydantic 1.x"""
try:
return model.model_validate(data) # pydantic 2.x
except AttributeError:
return model.parse_obj(data) # pydantic 1.x


class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
# A simple subclass of fastapi's APIRouter which treats URLs with a
# trailing "/" the same as URLs without. Docs will only contain URLs
Expand Down Expand Up @@ -555,7 +547,7 @@ async def create_database(
def process_create_database(
tenant: str, headers: Headers, raw_body: bytes
) -> None:
db = validate_model(CreateDatabase, orjson.loads(raw_body))
db = CreateDatabase.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
Expand Down Expand Up @@ -633,7 +625,7 @@ async def create_tenant(
request: Request,
) -> None:
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = validate_model(CreateTenant, orjson.loads(raw_body))
tenant = CreateTenant.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
Expand Down Expand Up @@ -785,7 +777,7 @@ async def create_collection(
def process_create_collection(
request: Request, tenant: str, database: str, raw_body: bytes
) -> CollectionModel:
create = validate_model(CreateCollection, orjson.loads(raw_body))
create = CreateCollection.model_validate(orjson.loads(raw_body))
if not create.configuration:
if create.metadata:
configuration = (
Expand Down Expand Up @@ -877,7 +869,7 @@ async def update_collection(
def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = validate_model(UpdateCollection, orjson.loads(raw_body))
update = UpdateCollection.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
request.headers,
Expand Down Expand Up @@ -950,7 +942,7 @@ async def add(
try:

def process_add(request: Request, raw_body: bytes) -> bool:
add = validate_model(AddEmbedding, orjson.loads(raw_body))
add = AddEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
request.headers,
Expand Down Expand Up @@ -999,7 +991,7 @@ async def update(
collection_id: str,
) -> None:
def process_update(request: Request, raw_body: bytes) -> bool:
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))
update = UpdateEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
Expand Down Expand Up @@ -1042,7 +1034,7 @@ async def upsert(
collection_id: str,
) -> None:
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))
upsert = AddEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
Expand Down Expand Up @@ -1088,7 +1080,7 @@ async def get(
request: Request,
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = validate_model(GetEmbedding, orjson.loads(raw_body))
get = GetEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
request.headers,
Expand Down Expand Up @@ -1139,7 +1131,7 @@ async def delete(
request: Request,
) -> None:
def process_delete(request: Request, raw_body: bytes) -> None:
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
delete = DeleteEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
request.headers,
Expand Down Expand Up @@ -1232,7 +1224,7 @@ async def get_nearest_neighbors(
"internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION
)
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = validate_model(QueryEmbedding, orjson.loads(raw_body))
query = QueryEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): Implemented.
self.sync_auth_request(
Expand Down Expand Up @@ -1521,7 +1513,7 @@ async def create_database_v1(
def process_create_database(
tenant: str, headers: Headers, raw_body: bytes
) -> None:
db = validate_model(CreateDatabase, orjson.loads(raw_body))
db = CreateDatabase.model_validate(orjson.loads(raw_body))

(
maybe_tenant,
Expand Down Expand Up @@ -1589,7 +1581,7 @@ async def create_tenant_v1(
request: Request,
) -> None:
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = validate_model(CreateTenant, orjson.loads(raw_body))
tenant = CreateTenant.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): v1
maybe_tenant, _ = self.sync_auth_and_get_tenant_and_database_for_request(
Expand Down Expand Up @@ -1720,7 +1712,7 @@ async def create_collection_v1(
def process_create_collection(
request: Request, tenant: str, database: str, raw_body: bytes
) -> CollectionModel:
create = validate_model(CreateCollection, orjson.loads(raw_body))
create = CreateCollection.model_validate(orjson.loads(raw_body))
configuration = (
CreateCollectionConfiguration()
if not create.configuration
Expand Down Expand Up @@ -1816,7 +1808,7 @@ async def update_collection_v1(
def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = validate_model(UpdateCollection, orjson.loads(raw_body))
update = UpdateCollection.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -1889,7 +1881,7 @@ async def add_v1(
try:

def process_add(request: Request, raw_body: bytes) -> bool:
add = validate_model(AddEmbedding, orjson.loads(raw_body))
add = AddEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -1931,7 +1923,7 @@ async def update_v1(
collection_id: str,
) -> None:
def process_update(request: Request, raw_body: bytes) -> bool:
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))
update = UpdateEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
Expand Down Expand Up @@ -1968,7 +1960,7 @@ async def upsert_v1(
collection_id: str,
) -> None:
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))
upsert = AddEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
Expand Down Expand Up @@ -2008,7 +2000,7 @@ async def get_v1(
request: Request,
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = validate_model(GetEmbedding, orjson.loads(raw_body))
get = GetEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -2053,7 +2045,7 @@ async def delete_v1(
request: Request,
) -> None:
def process_delete(request: Request, raw_body: bytes) -> None:
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
delete = DeleteEmbedding.model_validate(orjson.loads(raw_body))
# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -2132,7 +2124,7 @@ async def get_nearest_neighbors_v1(
request: Request,
) -> QueryResult:
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = validate_model(QueryEmbedding, orjson.loads(raw_body))
query = QueryEmbedding.model_validate(orjson.loads(raw_body))

# NOTE(rescrv, iron will auth): v1
self.sync_auth_and_get_tenant_and_database_for_request(
Expand Down
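
With Pydantic 1.x support dropped, the validate_model() shim removed above disappears entirely and every handler parses its request body with model_validate() directly. A short sketch of that call, using a stand-in model rather than the server's actual CreateDatabase schema:

    import orjson
    from pydantic import BaseModel


    class CreateDatabase(BaseModel):  # stand-in for the real request model
        name: str


    raw_body = b'{"name": "my-database"}'
    # Pydantic 2.x: model_validate() replaces the parse_obj() path the old shim fell back to.
    db = CreateDatabase.model_validate(orjson.loads(raw_body))
    print(db.name)  # my-database
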
15 changes: 4 additions & 11 deletions chromadb/types.py
Expand Up @@ -120,7 +120,7 @@ def __getitem__(self, key: str) -> Optional[Any]:
if key == "configuration":
return self.get_configuration()
# For the other model attributes we allow the user to access them directly
if key in self.get_model_fields():
if key in type(self).model_fields:
return getattr(self, key)
return None

Expand All @@ -130,18 +130,18 @@ def __setitem__(self, key: str, value: Any) -> None:
# For the model attributes we allow the user to access them directly
if key == "configuration":
self.set_configuration(value)
if key in self.get_model_fields():
if key in type(self).model_fields:
setattr(self, key, value)
else:
raise KeyError(
f"No such key: {key}, valid keys are: {self.get_model_fields()}"
f"No such key: {key}, valid keys are: {type(self).model_fields}"
)

def __eq__(self, __value: object) -> bool:
# Check that all the model fields are equal
if not isinstance(__value, Collection):
return False
for field in self.get_model_fields():
for field in type(self).model_fields:
if getattr(self, field) != getattr(__value, field):
return False
return True
Expand All @@ -167,13 +167,6 @@ def set_serialized_schema(self, serialized_schema: Dict[str, Any]) -> None:
"""Sets the serialized_schema of the collection"""
self.serialized_schema = serialized_schema

def get_model_fields(self) -> Dict[Any, Any]:
"""Used for backward compatibility with Pydantic 1.x"""
try:
return type(self).model_fields # pydantic 2.x, pydantic 3.x
except AttributeError:
return self.__fields__ # pydantic 1.x

def pretty_schema(self) -> str:
"""Returns a pretty-printed version of the serialized schema."""
if self.serialized_schema is None:
Expand Down
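
The get_model_fields() helper is likewise removed: field names now come from model_fields, which Pydantic 2.x defines on the class, hence the type(self).model_fields spelling in the lookups above. A small illustrative sketch (the Item model is not part of chromadb.types):

    from pydantic import BaseModel


    class Item(BaseModel):  # illustrative model only
        name: str
        dimension: int = 0


    item = Item(name="vectors")
    # model_fields is a class-level dict of field name -> FieldInfo in Pydantic 2.x;
    # accessing it via the class avoids the instance-access deprecation warning.
    print(list(type(item).model_fields))      # ['name', 'dimension']
    print("name" in type(item).model_fields)  # True
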
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Expand Up @@ -20,7 +20,7 @@ dependencies = [
'opentelemetry-sdk>=1.2.0',
'overrides >= 7.3.1',
'posthog >= 2.4.0, < 6.0.0',
'pydantic>=1.9',
'pydantic>=2.12.4',
'typing_extensions >= 4.5.0',
'tenacity>=8.2.3',
'PyYAML>=6.0.0',
Expand Down
2 changes: 1 addition & 1 deletion clients/python/requirements.txt
Expand Up @@ -7,7 +7,7 @@ orjson>=3.9.12
overrides >= 7.3.1
posthog>=2.4.0,<6.0.0
pybase64>=1.4.1
pydantic>=1.9
pydantic>=2.12.4
PyYAML>=6.0.0
tenacity>=8.2.3
typing_extensions >= 4.5.0
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -7,7 +7,7 @@ description = "Chroma."
readme = "README.md"
requires-python = ">=3.9"
classifiers = ["Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent"]
dependencies = ['build >= 1.0.3', 'pydantic >= 1.9', 'pybase64>=1.4.1', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.22.5', 'posthog >= 2.4.0, < 6.0.0', 'typing_extensions >= 4.5.0', 'onnxruntime >= 1.14.1', 'opentelemetry-api>=1.2.0', 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-sdk>=1.2.0', 'tokenizers >= 0.13.2', 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', 'importlib-resources', 'graphlib_backport >= 1.0.3; python_version < "3.9"', 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', 'kubernetes>=28.1.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', 'mmh3>=4.0.1', 'orjson>=3.9.12', 'httpx>=0.27.0', 'rich>=10.11.0', 'jsonschema>=4.19.0']
dependencies = ['build >= 1.0.3', 'pydantic >= 2.0, <3.0', 'pydantic-settings >= 2.0, <3.0', 'pybase64>=1.4.1', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.22.5', 'posthog >= 2.4.0, < 6.0.0', 'typing_extensions >= 4.5.0', 'onnxruntime >= 1.14.1', 'opentelemetry-api>=1.2.0', 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-sdk>=1.2.0', 'tokenizers >= 0.13.2', 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', 'importlib-resources', 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', 'kubernetes>=28.1.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', 'mmh3>=4.0.1', 'orjson>=3.9.12', 'httpx>=0.27.0', 'rich>=10.11.0', 'jsonschema>=4.19.0']

[project.optional-dependencies]
dev = ['chroma-hnswlib==0.7.6', 'fastapi>=0.115.9', 'opentelemetry-instrumentation-fastapi>=0.41b0']
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,4 @@
bcrypt>=4.0.1
graphlib_backport==1.0.3; python_version < '3.9'
grpcio>=1.58.0
httpx>=0.27.0
importlib-resources
Expand All @@ -15,7 +14,8 @@ orjson>=3.9.12
overrides>=7.3.1
posthog>=2.4.0,<6.0.0
pybase64>=1.4.1
pydantic>=1.9
pydantic>=2.0,<3.0
pydantic-settings>=2.0,<3.0
pypika>=0.48.9
PyYAML>=6.0.0
rich>=10.11.0
Expand Down
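
Because BaseSettings now lives in the separate pydantic-settings distribution, the pins above have to move in lockstep. A hedged sanity check, assuming both packages are installed from the new requirements (not part of the repository itself):

    from importlib.metadata import version

    # Purely illustrative check of the installed versions against the new pins.
    assert version("pydantic").startswith("2."), "pydantic 2.x required"
    assert version("pydantic-settings").startswith("2."), "pydantic-settings 2.x required"
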
2 changes: 1 addition & 1 deletion rust/python_bindings/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "chromadb_rust_bindings"
version = "0.1.0"
version = "1.3.3"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down