Skip to content

Commit 8d05470

Browse files
authored
[ENH] Validate schemas for ef in schema (#5833)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - This PR adds validation for embedding functions in schema when serializing and adds a missing schema for baseten ef - New functionality - ... ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent f557b7c commit 8d05470

File tree

7 files changed

+620
-6
lines changed

7 files changed

+620
-6
lines changed

chromadb/api/types.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,7 +1530,9 @@ class VectorIndexConfig(BaseModel):
15301530
model_config = {"arbitrary_types_allowed": True}
15311531
space: Optional[Space] = None
15321532
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
1533-
source_key: Optional[str] = None # key to source the vector from (accepts str or Key)
1533+
source_key: Optional[
1534+
str
1535+
] = None # key to source the vector from (accepts str or Key)
15341536
hnsw: Optional[HnswIndexConfig] = None
15351537
spann: Optional[SpannIndexConfig] = None
15361538

@@ -1542,6 +1544,7 @@ def validate_source_key_field(cls, v: Any) -> Optional[str]:
15421544
return None
15431545
# Import Key at runtime to avoid circular import
15441546
from chromadb.execution.expression.operator import Key as KeyType
1547+
15451548
if isinstance(v, KeyType):
15461549
v = v.name # Extract string from Key
15471550
elif isinstance(v, str):
@@ -1577,7 +1580,9 @@ class SparseVectorIndexConfig(BaseModel):
15771580
model_config = {"arbitrary_types_allowed": True}
15781581
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
15791582
embedding_function: Optional[Any] = None
1580-
source_key: Optional[str] = None # key to source the sparse vector from (accepts str or Key)
1583+
source_key: Optional[
1584+
str
1585+
] = None # key to source the sparse vector from (accepts str or Key)
15811586
bm25: Optional[bool] = None
15821587

15831588
@field_validator("source_key", mode="before")
@@ -1588,6 +1593,7 @@ def validate_source_key_field(cls, v: Any) -> Optional[str]:
15881593
return None
15891594
# Import Key at runtime to avoid circular import
15901595
from chromadb.execution.expression.operator import Key as KeyType
1596+
15911597
if isinstance(v, KeyType):
15921598
v = v.name # Extract string from Key
15931599
elif isinstance(v, str):
@@ -1787,11 +1793,14 @@ def __init__(self) -> None:
17871793
self._initialize_keys()
17881794

17891795
def create_index(
1790-
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
1796+
self,
1797+
config: Optional[IndexConfig] = None,
1798+
key: Optional[Union[str, "Key"]] = None,
17911799
) -> "Schema":
17921800
"""Create an index configuration."""
17931801
# Convert Key to string if provided
17941802
from chromadb.execution.expression.operator import Key as KeyType
1803+
17951804
if key is not None and isinstance(key, KeyType):
17961805
key = key.name
17971806

@@ -1869,11 +1878,14 @@ def create_index(
18691878
return self
18701879

18711880
def delete_index(
1872-
self, config: Optional[IndexConfig] = None, key: Optional[Union[str, "Key"]] = None
1881+
self,
1882+
config: Optional[IndexConfig] = None,
1883+
key: Optional[Union[str, "Key"]] = None,
18731884
) -> "Schema":
18741885
"""Disable an index configuration (set enabled=False)."""
18751886
# Convert Key to string if provided
18761887
from chromadb.execution.expression.operator import Key as KeyType
1888+
18771889
if key is not None and isinstance(key, KeyType):
18781890
key = key.name
18791891

@@ -2410,6 +2422,10 @@ def _serialize_config(self, config: IndexConfig) -> Dict[str, Any]:
24102422
if embedding_func.is_legacy():
24112423
config_dict["embedding_function"] = {"type": "legacy"}
24122424
else:
2425+
if hasattr(embedding_func, "validate_config"):
2426+
embedding_func.validate_config(
2427+
embedding_func.get_config()
2428+
)
24132429
config_dict["embedding_function"] = {
24142430
"name": embedding_func.name(),
24152431
"type": "known",
@@ -2439,6 +2455,8 @@ def _serialize_config(self, config: IndexConfig) -> Dict[str, Any]:
24392455
config_dict["embedding_function"] = {"type": "unknown"}
24402456
else:
24412457
embedding_func = cast(SparseEmbeddingFunction, embedding_func) # type: ignore
2458+
if hasattr(embedding_func, "validate_config"):
2459+
embedding_func.validate_config(embedding_func.get_config())
24422460
config_dict["embedding_function"] = {
24432461
"name": embedding_func.name(),
24442462
"type": "known",

0 commit comments

Comments
 (0)