diff --git a/Cargo.lock b/Cargo.lock index 1be1f95059e..4b9a96aeee0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2044,7 +2044,7 @@ dependencies = [ [[package]] name = "chromadb_rust_bindings" -version = "0.1.0" +version = "1.3.3" dependencies = [ "chroma-cache", "chroma-cli", diff --git a/chromadb/config.py b/chromadb/config.py index ee264bcf03c..8496eb8d3ab 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -12,16 +12,8 @@ from typing_extensions import Literal import platform -in_pydantic_v2 = False -try: - from pydantic import BaseSettings -except ImportError: - in_pydantic_v2 = True - from pydantic.v1 import BaseSettings - from pydantic.v1 import validator - -if not in_pydantic_v2: - from pydantic import validator # type: ignore # noqa +from pydantic_settings import BaseSettings +from pydantic import field_validator # The thin client will have a flag to control which implementations to use is_thin_client = False @@ -117,7 +109,13 @@ class RoutingMode(Enum): ID = "id" -class Settings(BaseSettings): # type: ignore +class Settings(BaseSettings): + model_config = { + "env_file": ".env", + "env_file_encoding": "utf-8", + "extra": "ignore", # Ignore extra environment variables not defined in model + } + # ============== # Generic config # ============== @@ -127,7 +125,8 @@ class Settings(BaseSettings): # type: ignore # Can be "chromadb.api.segment.SegmentAPI" or "chromadb.api.fastapi.FastAPI" or "chromadb.api.rust.RustBindingsAPI" chroma_api_impl: str = "chromadb.api.rust.RustBindingsAPI" - @validator("chroma_server_nofile", pre=True, always=True, allow_reuse=True) + @field_validator("chroma_server_nofile", mode="before") + @classmethod def empty_str_to_none(cls, v: str) -> Optional[str]: if type(v) is str and v.strip() == "": return None @@ -256,7 +255,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]: chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider" worker_memberlist_name: str = "query-service-memberlist" - chroma_coordinator_host = "localhost" + chroma_coordinator_host: str = "localhost" # TODO this is the sysdb port. Should probably rename it. chroma_server_grpc_port: Optional[int] = None chroma_sysdb_impl: str = "chromadb.db.impl.sqlite.SqliteDB" @@ -270,8 +269,8 @@ def empty_str_to_none(cls, v: str) -> Optional[str]: chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor" chroma_query_replication_factor: int = 2 - chroma_logservice_host = "localhost" - chroma_logservice_port = 50052 + chroma_logservice_host: str = "localhost" + chroma_logservice_port: int = 50052 chroma_quota_provider_impl: Optional[str] = None chroma_rate_limiting_provider_impl: Optional[str] = None @@ -323,10 +322,6 @@ def __getitem__(self, key: str) -> Any: raise ValueError(LEGACY_ERROR) return val - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - T = TypeVar("T", bound="Component") diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index a3fc1b063e5..1e568108d30 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -152,14 +152,6 @@ async def check_http_version_middleware( D = TypeVar("D", bound=BaseModel, contravariant=True) -def validate_model(model: Type[D], data: Any) -> D: # type: ignore - """Used for backward compatibility with Pydantic 1.x""" - try: - return model.model_validate(data) # pydantic 2.x - except AttributeError: - return model.parse_obj(data) # pydantic 1.x - - class ChromaAPIRouter(fastapi.APIRouter): # type: ignore # A simple subclass of fastapi's APIRouter which treats URLs with a # trailing "/" the same as URLs without. Docs will only contain URLs @@ -555,7 +547,7 @@ async def create_database( def process_create_database( tenant: str, headers: Headers, raw_body: bytes ) -> None: - db = validate_model(CreateDatabase, orjson.loads(raw_body)) + db = CreateDatabase.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( @@ -633,7 +625,7 @@ async def create_tenant( request: Request, ) -> None: def process_create_tenant(request: Request, raw_body: bytes) -> None: - tenant = validate_model(CreateTenant, orjson.loads(raw_body)) + tenant = CreateTenant.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( @@ -785,7 +777,7 @@ async def create_collection( def process_create_collection( request: Request, tenant: str, database: str, raw_body: bytes ) -> CollectionModel: - create = validate_model(CreateCollection, orjson.loads(raw_body)) + create = CreateCollection.model_validate(orjson.loads(raw_body)) if not create.configuration: if create.metadata: configuration = ( @@ -877,7 +869,7 @@ async def update_collection( def process_update_collection( request: Request, collection_id: str, raw_body: bytes ) -> None: - update = validate_model(UpdateCollection, orjson.loads(raw_body)) + update = UpdateCollection.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( request.headers, @@ -950,7 +942,7 @@ async def add( try: def process_add(request: Request, raw_body: bytes) -> bool: - add = validate_model(AddEmbedding, orjson.loads(raw_body)) + add = AddEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( request.headers, @@ -999,7 +991,7 @@ async def update( collection_id: str, ) -> None: def process_update(request: Request, raw_body: bytes) -> bool: - update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) + update = UpdateEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( @@ -1042,7 +1034,7 @@ async def upsert( collection_id: str, ) -> None: def process_upsert(request: Request, raw_body: bytes) -> bool: - upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) + upsert = AddEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( @@ -1088,7 +1080,7 @@ async def get( request: Request, ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: - get = validate_model(GetEmbedding, orjson.loads(raw_body)) + get = GetEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( request.headers, @@ -1139,7 +1131,7 @@ async def delete( request: Request, ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: - delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) + delete = DeleteEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( request.headers, @@ -1232,7 +1224,7 @@ async def get_nearest_neighbors( "internal.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION ) def process_query(request: Request, raw_body: bytes) -> QueryResult: - query = validate_model(QueryEmbedding, orjson.loads(raw_body)) + query = QueryEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): Implemented. self.sync_auth_request( @@ -1521,7 +1513,7 @@ async def create_database_v1( def process_create_database( tenant: str, headers: Headers, raw_body: bytes ) -> None: - db = validate_model(CreateDatabase, orjson.loads(raw_body)) + db = CreateDatabase.model_validate(orjson.loads(raw_body)) ( maybe_tenant, @@ -1589,7 +1581,7 @@ async def create_tenant_v1( request: Request, ) -> None: def process_create_tenant(request: Request, raw_body: bytes) -> None: - tenant = validate_model(CreateTenant, orjson.loads(raw_body)) + tenant = CreateTenant.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 maybe_tenant, _ = self.sync_auth_and_get_tenant_and_database_for_request( @@ -1720,7 +1712,7 @@ async def create_collection_v1( def process_create_collection( request: Request, tenant: str, database: str, raw_body: bytes ) -> CollectionModel: - create = validate_model(CreateCollection, orjson.loads(raw_body)) + create = CreateCollection.model_validate(orjson.loads(raw_body)) configuration = ( CreateCollectionConfiguration() if not create.configuration @@ -1816,7 +1808,7 @@ async def update_collection_v1( def process_update_collection( request: Request, collection_id: str, raw_body: bytes ) -> None: - update = validate_model(UpdateCollection, orjson.loads(raw_body)) + update = UpdateCollection.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( request.headers, @@ -1889,7 +1881,7 @@ async def add_v1( try: def process_add(request: Request, raw_body: bytes) -> bool: - add = validate_model(AddEmbedding, orjson.loads(raw_body)) + add = AddEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( request.headers, @@ -1931,7 +1923,7 @@ async def update_v1( collection_id: str, ) -> None: def process_update(request: Request, raw_body: bytes) -> bool: - update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) + update = UpdateEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( @@ -1968,7 +1960,7 @@ async def upsert_v1( collection_id: str, ) -> None: def process_upsert(request: Request, raw_body: bytes) -> bool: - upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) + upsert = AddEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( @@ -2008,7 +2000,7 @@ async def get_v1( request: Request, ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: - get = validate_model(GetEmbedding, orjson.loads(raw_body)) + get = GetEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( request.headers, @@ -2053,7 +2045,7 @@ async def delete_v1( request: Request, ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: - delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) + delete = DeleteEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( request.headers, @@ -2132,7 +2124,7 @@ async def get_nearest_neighbors_v1( request: Request, ) -> QueryResult: def process_query(request: Request, raw_body: bytes) -> QueryResult: - query = validate_model(QueryEmbedding, orjson.loads(raw_body)) + query = QueryEmbedding.model_validate(orjson.loads(raw_body)) # NOTE(rescrv, iron will auth): v1 self.sync_auth_and_get_tenant_and_database_for_request( diff --git a/chromadb/types.py b/chromadb/types.py index be48337e314..5ffe2b22dcd 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -120,7 +120,7 @@ def __getitem__(self, key: str) -> Optional[Any]: if key == "configuration": return self.get_configuration() # For the other model attributes we allow the user to access them directly - if key in self.get_model_fields(): + if key in type(self).model_fields: return getattr(self, key) return None @@ -130,18 +130,18 @@ def __setitem__(self, key: str, value: Any) -> None: # For the model attributes we allow the user to access them directly if key == "configuration": self.set_configuration(value) - if key in self.get_model_fields(): + if key in type(self).model_fields: setattr(self, key, value) else: raise KeyError( - f"No such key: {key}, valid keys are: {self.get_model_fields()}" + f"No such key: {key}, valid keys are: {type(self).model_fields}" ) def __eq__(self, __value: object) -> bool: # Check that all the model fields are equal if not isinstance(__value, Collection): return False - for field in self.get_model_fields(): + for field in type(self).model_fields: if getattr(self, field) != getattr(__value, field): return False return True @@ -167,13 +167,6 @@ def set_serialized_schema(self, serialized_schema: Dict[str, Any]) -> None: """Sets the serialized_schema of the collection""" self.serialized_schema = serialized_schema - def get_model_fields(self) -> Dict[Any, Any]: - """Used for backward compatibility with Pydantic 1.x""" - try: - return type(self).model_fields # pydantic 2.x, pydantic 3.x - except AttributeError: - return self.__fields__ # pydantic 1.x - def pretty_schema(self) -> str: """Returns a pretty-printed version of the serialized schema.""" if self.serialized_schema is None: diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 8448765b299..c6cb869fe77 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ 'opentelemetry-sdk>=1.2.0', 'overrides >= 7.3.1', 'posthog >= 2.4.0, < 6.0.0', - 'pydantic>=1.9', + 'pydantic>=2.12.4', 'typing_extensions >= 4.5.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', diff --git a/clients/python/requirements.txt b/clients/python/requirements.txt index 94f622de48f..1633132e1f0 100644 --- a/clients/python/requirements.txt +++ b/clients/python/requirements.txt @@ -7,7 +7,7 @@ orjson>=3.9.12 overrides >= 7.3.1 posthog>=2.4.0,<6.0.0 pybase64>=1.4.1 -pydantic>=1.9 +pydantic>=2.12.4 PyYAML>=6.0.0 tenacity>=8.2.3 typing_extensions >= 4.5.0 diff --git a/pyproject.toml b/pyproject.toml index 6cb42955dea..3b5aea83133 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ description = "Chroma." readme = "README.md" requires-python = ">=3.9" classifiers = ["Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent"] -dependencies = ['build >= 1.0.3', 'pydantic >= 1.9', 'pybase64>=1.4.1', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.22.5', 'posthog >= 2.4.0, < 6.0.0', 'typing_extensions >= 4.5.0', 'onnxruntime >= 1.14.1', 'opentelemetry-api>=1.2.0', 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-sdk>=1.2.0', 'tokenizers >= 0.13.2', 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', 'importlib-resources', 'graphlib_backport >= 1.0.3; python_version < "3.9"', 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', 'kubernetes>=28.1.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', 'mmh3>=4.0.1', 'orjson>=3.9.12', 'httpx>=0.27.0', 'rich>=10.11.0', 'jsonschema>=4.19.0'] +dependencies = ['build >= 1.0.3', 'pydantic >= 2.0, <3.0', 'pydantic-settings >= 2.0, <3.0', 'pybase64>=1.4.1', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.22.5', 'posthog >= 2.4.0, < 6.0.0', 'typing_extensions >= 4.5.0', 'onnxruntime >= 1.14.1', 'opentelemetry-api>=1.2.0', 'opentelemetry-exporter-otlp-proto-grpc>=1.2.0', 'opentelemetry-sdk>=1.2.0', 'tokenizers >= 0.13.2', 'pypika >= 0.48.9', 'tqdm >= 4.65.0', 'overrides >= 7.3.1', 'importlib-resources', 'grpcio >= 1.58.0', 'bcrypt >= 4.0.1', 'typer >= 0.9.0', 'kubernetes>=28.1.0', 'tenacity>=8.2.3', 'PyYAML>=6.0.0', 'mmh3>=4.0.1', 'orjson>=3.9.12', 'httpx>=0.27.0', 'rich>=10.11.0', 'jsonschema>=4.19.0'] [project.optional-dependencies] dev = ['chroma-hnswlib==0.7.6', 'fastapi>=0.115.9', 'opentelemetry-instrumentation-fastapi>=0.41b0'] diff --git a/requirements.txt b/requirements.txt index 256f531fed4..88961ed4ac8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ bcrypt>=4.0.1 -graphlib_backport==1.0.3; python_version < '3.9' grpcio>=1.58.0 httpx>=0.27.0 importlib-resources @@ -15,7 +14,8 @@ orjson>=3.9.12 overrides>=7.3.1 posthog>=2.4.0,<6.0.0 pybase64>=1.4.1 -pydantic>=1.9 +pydantic>=2.0,<3.0 +pydantic-settings>=2.0,<3.0 pypika>=0.48.9 PyYAML>=6.0.0 rich>=10.11.0 diff --git a/rust/python_bindings/Cargo.toml b/rust/python_bindings/Cargo.toml index 65b7c9dc160..4e6596aac0a 100644 --- a/rust/python_bindings/Cargo.toml +++ b/rust/python_bindings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chromadb_rust_bindings" -version = "0.1.0" +version = "1.3.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html