Skip to content

Commit 6c083cf

Browse files
committed
[ENH] add query config on collection configuration
1 parent 5e1ec94 commit 6c083cf

File tree

9 files changed

+737
-62
lines changed

9 files changed

+737
-62
lines changed

chromadb/api/collection_configuration.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
9999
raise ValueError(
100100
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101101
)
102-
103102
else:
104103
ef = None
105104

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148147
if ef is None:
149148
ef = None
150149
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156150

157151
if ef is not None:
158152
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
260254
embedding_function: Optional[EmbeddingFunction] # type: ignore
261255

262256

263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
271-
272-
273257
def create_collection_configuration_from_legacy_collection_metadata(
274258
metadata: CollectionMetadata,
275259
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301285
return CreateCollectionConfiguration(hnsw=hnsw_config)
302286

303287

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311288
# TODO: make warnings prettier and add link to migration docs
312289
def load_create_collection_configuration_from_json(
313290
json_map: Dict[str, Any]

chromadb/api/models/CollectionCommon.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
559568
if self._embedding_function is not None and not isinstance(
560569
self._embedding_function, ef.DefaultEmbeddingFunction
561570
):
562-
return self._embedding_function(input=input)
571+
if is_query:
572+
return self._embedding_function.embed_query(input=input)
573+
else:
574+
return self._embedding_function(input=input)
575+
563576
config_ef = self.configuration.get("embedding_function")
564577
if config_ef is not None:
565-
return config_ef(input=input)
578+
if is_query:
579+
return config_ef.embed_query(input=input)
580+
else:
581+
return config_ef(input=input)
566582
if self._embedding_function is None:
567583
raise ValueError(
568584
"You must provide an embedding function to compute embeddings."
569585
"https://docs.trychroma.com/guides/embeddings"
570586
)
571-
return self._embedding_function(input=input)
587+
if is_query:
588+
return self._embedding_function.embed_query(input=input)
589+
else:
590+
return self._embedding_function(input=input)

chromadb/api/types.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def maybe_cast_one_to_many(target: Optional[OneOrMany[T]]) -> Optional[List[T]]:
7979
PyEmbeddings = List[PyEmbedding]
8080
Embedding = Vector
8181
Embeddings = List[Embedding]
82+
SparseEmbedding = SparseVector
83+
SparseEmbeddings = List[SparseEmbedding]
8284

8385
Space = Literal["cosine", "l2", "ip"]
8486

@@ -569,6 +571,13 @@ class EmbeddingFunction(Protocol[D]):
569571
def __call__(self, input: D) -> Embeddings:
570572
...
571573

574+
def embed_query(self, input: D) -> Embeddings:
575+
"""
576+
Get the embeddings for a query input.
577+
This method is optional, and if not implemented, the default behavior is to call __call__.
578+
"""
579+
return self.__call__(input)
580+
572581
def __init_subclass__(cls) -> None:
573582
super().__init_subclass__()
574583
# Raise an exception if __call__ is not defined since it is expected to be defined
@@ -1096,6 +1105,21 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
10961105
return embeddings
10971106

10981107

1108+
def validate_sparse_embeddings(embeddings: SparseEmbeddings) -> SparseEmbeddings:
1109+
"""Validates sparse embeddings to ensure it is a list of sparse vectors"""
1110+
if not isinstance(embeddings, list):
1111+
raise ValueError(
1112+
f"Expected sparse embeddings to be a list, got {type(embeddings).__name__}"
1113+
)
1114+
if len(embeddings) == 0:
1115+
raise ValueError(
1116+
f"Expected sparse embeddings to be a non-empty list, got {len(embeddings)} sparse embeddings"
1117+
)
1118+
for embedding in embeddings:
1119+
validate_sparse_vector(embedding)
1120+
return embeddings
1121+
1122+
10991123
def validate_documents(documents: Documents, nullable: bool = False) -> None:
11001124
"""Validates documents to ensure it is a list of strings"""
11011125
if not isinstance(documents, list):
@@ -1150,3 +1174,95 @@ def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:
11501174

11511175
def convert_list_embeddings_to_np(embeddings: PyEmbeddings) -> Embeddings:
11521176
return [np.array(embedding) for embedding in embeddings]
1177+
1178+
1179+
@runtime_checkable
1180+
class SparseEmbeddingFunction(Protocol[D]):
1181+
"""
1182+
A protocol for sparse embedding functions. To implement a new sparse embedding function,
1183+
you need to implement the following methods at minimum:
1184+
- __call__
1185+
1186+
For future compatibility, it is strongly recommended to also implement:
1187+
- __init__
1188+
- name
1189+
- build_from_config
1190+
- get_config
1191+
"""
1192+
1193+
@abstractmethod
1194+
def __call__(self, input: D) -> SparseEmbeddings:
1195+
...
1196+
1197+
def embed_query(self, input: D) -> SparseEmbeddings:
1198+
"""
1199+
Get the embeddings for a query input.
1200+
This method is optional, and if not implemented, the default behavior is to call __call__.
1201+
"""
1202+
return self.__call__(input)
1203+
1204+
def __init_subclass__(cls) -> None:
1205+
super().__init_subclass__()
1206+
# Raise an exception if __call__ is not defined since it is expected to be defined
1207+
call = getattr(cls, "__call__")
1208+
1209+
def __call__(self: SparseEmbeddingFunction[D], input: D) -> SparseEmbeddings:
1210+
result = call(self, input)
1211+
assert result is not None
1212+
return validate_sparse_embeddings(cast(SparseEmbeddings, result))
1213+
1214+
setattr(cls, "__call__", __call__)
1215+
1216+
def embed_with_retries(
1217+
self, input: D, **retry_kwargs: Dict[str, Any]
1218+
) -> SparseEmbeddings:
1219+
return cast(SparseEmbeddings, retry(**retry_kwargs)(self.__call__)(input))
1220+
1221+
@abstractmethod
1222+
def __init__(self, *args: Any, **kwargs: Any) -> None:
1223+
"""
1224+
Initialize the embedding function.
1225+
Pass any arguments that will be needed to build the embedding function
1226+
config.
1227+
"""
1228+
...
1229+
1230+
@staticmethod
1231+
@abstractmethod
1232+
def name() -> str:
1233+
"""
1234+
Return the name of the embedding function.
1235+
"""
1236+
...
1237+
1238+
@staticmethod
1239+
@abstractmethod
1240+
def build_from_config(config: Dict[str, Any]) -> "SparseEmbeddingFunction[D]":
1241+
"""
1242+
Build the embedding function from a config, which will be used to
1243+
deserialize the embedding function.
1244+
"""
1245+
...
1246+
1247+
@abstractmethod
1248+
def get_config(self) -> Dict[str, Any]:
1249+
"""
1250+
Return the config for the embedding function, which will be used to
1251+
serialize the embedding function.
1252+
"""
1253+
...
1254+
1255+
def validate_config_update(
1256+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
1257+
) -> None:
1258+
"""
1259+
Validate the update to the config.
1260+
"""
1261+
return
1262+
1263+
@staticmethod
1264+
def validate_config(config: Dict[str, Any]) -> None:
1265+
"""
1266+
Validate the config.
1267+
"""
1268+
return

chromadb/test/configurations/test_collection_configuration.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from chromadb.test.conftest import ClientFactories
2626
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
2727
from chromadb.types import Collection as CollectionModel
28+
from typing import Optional, TypedDict
2829

2930

3031
class LegacyEmbeddingFunction(EmbeddingFunction[Embeddable]):
@@ -1616,3 +1617,125 @@ def test_default_space_custom_embedding_function_with_metadata_and_config(
16161617
spann_config = coll.configuration.get("spann")
16171618
assert spann_config is not None
16181619
assert spann_config.get("space") == "ip"
1620+
1621+
1622+
class CustomEmbeddingFunctionQueryConfig(TypedDict):
1623+
task: str
1624+
1625+
1626+
@register_embedding_function
1627+
class CustomEmbeddingFunctionWithQueryConfig(EmbeddingFunction[Embeddable]):
1628+
def __init__(
1629+
self,
1630+
task: str,
1631+
model_name: str,
1632+
dim: int = 3,
1633+
query_config: Optional[CustomEmbeddingFunctionQueryConfig] = None,
1634+
):
1635+
self._dim = dim
1636+
self._model_name = model_name
1637+
self._task = task
1638+
self._query_config = query_config
1639+
1640+
def __call__(self, input: Embeddable) -> Embeddings:
1641+
return cast(Embeddings, np.array([[1.0] * self._dim], dtype=np.float32))
1642+
1643+
def embed_query(self, input: Embeddable) -> Embeddings:
1644+
if self._query_config is not None and self._query_config.get("task") == "query":
1645+
return cast(Embeddings, np.array([[2.0] * self._dim], dtype=np.float32))
1646+
else:
1647+
return self.__call__(input)
1648+
1649+
@staticmethod
1650+
def name() -> str:
1651+
return "custom_ef_with_query_config"
1652+
1653+
def get_config(self) -> Dict[str, Any]:
1654+
return {
1655+
"model_name": self._model_name,
1656+
"dim": self._dim,
1657+
"task": self._task,
1658+
"query_config": self._query_config,
1659+
}
1660+
1661+
@staticmethod
1662+
def build_from_config(
1663+
config: Dict[str, Any]
1664+
) -> "CustomEmbeddingFunctionWithQueryConfig":
1665+
model_name = config.get("model_name")
1666+
dim = config.get("dim")
1667+
task = config.get("task")
1668+
query_config = config.get("query_config")
1669+
1670+
if model_name is None or dim is None:
1671+
assert False, "This code should not be reached"
1672+
1673+
return CustomEmbeddingFunctionWithQueryConfig(
1674+
model_name=model_name, dim=dim, task=task, query_config=query_config # type: ignore
1675+
)
1676+
1677+
def default_space(self) -> Space:
1678+
return "cosine"
1679+
1680+
def supported_spaces(self) -> List[Space]:
1681+
return ["cosine"]
1682+
1683+
1684+
def test_custom_embedding_function_with_query_config(client: ClientAPI) -> None:
1685+
client.reset()
1686+
coll = client.create_collection(
1687+
name="test_custom_embedding_function_with_query_config",
1688+
embedding_function=CustomEmbeddingFunctionWithQueryConfig(
1689+
task="document",
1690+
model_name="i_want_anything",
1691+
dim=3,
1692+
query_config={"task": "query"},
1693+
),
1694+
)
1695+
assert coll is not None
1696+
ef = coll.configuration.get("embedding_function")
1697+
assert ef is not None
1698+
assert ef.name() == "custom_ef_with_query_config"
1699+
assert ef.get_config() == {
1700+
"model_name": "i_want_anything",
1701+
"dim": 3,
1702+
"task": "document",
1703+
"query_config": {"task": "query"},
1704+
}
1705+
assert ef.default_space() == "cosine"
1706+
assert ef.supported_spaces() == ["cosine"]
1707+
assert np.array_equal(
1708+
ef.embed_query(input="How many people in Berlin?"),
1709+
np.array([[2.0, 2.0, 2.0]], dtype=np.float32),
1710+
)
1711+
1712+
1713+
def test_deserializing_custom_embedding_function_with_query_config_no_query_config(
1714+
client: ClientAPI,
1715+
) -> None:
1716+
json_string = """
1717+
{
1718+
"embedding_function": {
1719+
"type": "known",
1720+
"name": "custom_ef_with_query_config",
1721+
"config": {"model_name": "i_want_anything", "dim": 3, "task": "document"}
1722+
}
1723+
}
1724+
"""
1725+
config = load_collection_configuration_from_json(json.loads(json_string))
1726+
assert config is not None
1727+
assert config.get("embedding_function") is not None
1728+
ef = config.get("embedding_function")
1729+
assert ef is not None
1730+
assert ef.get_config() == {
1731+
"model_name": "i_want_anything",
1732+
"dim": 3,
1733+
"task": "document",
1734+
"query_config": None,
1735+
}
1736+
assert ef.default_space() == "cosine"
1737+
assert ef.supported_spaces() == ["cosine"]
1738+
assert np.array_equal(
1739+
ef.embed_query(input="How many people in Berlin?"),
1740+
np.array([[1.0, 1.0, 1.0]], dtype=np.float32),
1741+
)

0 commit comments

Comments
 (0)