Skip to content

Commit 5edc33f

Browse files
committed
[ENH] add query config on collection configuration
1 parent 004c30f commit 5edc33f

File tree

6 files changed

+258
-61
lines changed

6 files changed

+258
-61
lines changed

chromadb/api/collection_configuration.py

Lines changed: 0 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
9999
raise ValueError(
100100
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101101
)
102-
103102
else:
104103
ef = None
105104

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148147
if ef is None:
149148
ef = None
150149
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156150

157151
if ef is not None:
158152
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
260254
embedding_function: Optional[EmbeddingFunction] # type: ignore
261255

262256

263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
271-
272-
273257
def create_collection_configuration_from_legacy_collection_metadata(
274258
metadata: CollectionMetadata,
275259
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301285
return CreateCollectionConfiguration(hnsw=hnsw_config)
302286

303287

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311288
# TODO: make warnings prettier and add link to migration docs
312289
def load_create_collection_configuration_from_json(
313290
json_map: Dict[str, Any]

chromadb/api/models/CollectionCommon.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
559568
if self._embedding_function is not None and not isinstance(
560569
self._embedding_function, ef.DefaultEmbeddingFunction
561570
):
562-
return self._embedding_function(input=input)
571+
if is_query:
572+
return self._embedding_function.embed_query(input=input)
573+
else:
574+
return self._embedding_function(input=input)
575+
563576
config_ef = self.configuration.get("embedding_function")
564577
if config_ef is not None:
565-
return config_ef(input=input)
578+
if is_query:
579+
return config_ef.embed_query(input=input)
580+
else:
581+
return config_ef(input=input)
566582
if self._embedding_function is None:
567583
raise ValueError(
568584
"You must provide an embedding function to compute embeddings."
569585
"https://docs.trychroma.com/guides/embeddings"
570586
)
571-
return self._embedding_function(input=input)
587+
if is_query:
588+
return self._embedding_function.embed_query(input=input)
589+
else:
590+
return self._embedding_function(input=input)

chromadb/api/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,13 @@ class EmbeddingFunction(Protocol[D]):
569569
def __call__(self, input: D) -> Embeddings:
570570
...
571571

572+
def embed_query(self, input: D) -> Embeddings:
573+
"""
574+
Get the embeddings for a query input.
575+
This method is optional, and if not implemented, the default behavior is to call __call__.
576+
"""
577+
return self.__call__(input)
578+
572579
def __init_subclass__(cls) -> None:
573580
super().__init_subclass__()
574581
# Raise an exception if __call__ is not defined since it is expected to be defined

chromadb/test/configurations/test_collection_configuration.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from chromadb.test.conftest import ClientFactories
2626
from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled
2727
from chromadb.types import Collection as CollectionModel
28+
from typing import Optional, TypedDict
2829

2930

3031
class LegacyEmbeddingFunction(EmbeddingFunction[Embeddable]):
@@ -1616,3 +1617,125 @@ def test_default_space_custom_embedding_function_with_metadata_and_config(
16161617
spann_config = coll.configuration.get("spann")
16171618
assert spann_config is not None
16181619
assert spann_config.get("space") == "ip"
1620+
1621+
1622+
class CustomEmbeddingFunctionQueryConfig(TypedDict):
1623+
task: str
1624+
1625+
1626+
@register_embedding_function
1627+
class CustomEmbeddingFunctionWithQueryConfig(EmbeddingFunction[Embeddable]):
1628+
def __init__(
1629+
self,
1630+
task: str,
1631+
model_name: str,
1632+
dim: int = 3,
1633+
query_config: Optional[CustomEmbeddingFunctionQueryConfig] = None,
1634+
):
1635+
self._dim = dim
1636+
self._model_name = model_name
1637+
self._task = task
1638+
self._query_config = query_config
1639+
1640+
def __call__(self, input: Embeddable) -> Embeddings:
1641+
return cast(Embeddings, np.array([[1.0] * self._dim], dtype=np.float32))
1642+
1643+
def embed_query(self, input: Embeddable) -> Embeddings:
1644+
if self._query_config is not None and self._query_config.get("task") == "query":
1645+
return cast(Embeddings, np.array([[2.0] * self._dim], dtype=np.float32))
1646+
else:
1647+
return self.__call__(input)
1648+
1649+
@staticmethod
1650+
def name() -> str:
1651+
return "custom_ef_with_query_config"
1652+
1653+
def get_config(self) -> Dict[str, Any]:
1654+
return {
1655+
"model_name": self._model_name,
1656+
"dim": self._dim,
1657+
"task": self._task,
1658+
"query_config": self._query_config,
1659+
}
1660+
1661+
@staticmethod
1662+
def build_from_config(
1663+
config: Dict[str, Any]
1664+
) -> "CustomEmbeddingFunctionWithQueryConfig":
1665+
model_name = config.get("model_name")
1666+
dim = config.get("dim")
1667+
task = config.get("task")
1668+
query_config = config.get("query_config")
1669+
1670+
if model_name is None or dim is None:
1671+
assert False, "This code should not be reached"
1672+
1673+
return CustomEmbeddingFunctionWithQueryConfig(
1674+
model_name=model_name, dim=dim, task=task, query_config=query_config # type: ignore
1675+
)
1676+
1677+
def default_space(self) -> Space:
1678+
return "cosine"
1679+
1680+
def supported_spaces(self) -> List[Space]:
1681+
return ["cosine"]
1682+
1683+
1684+
def test_custom_embedding_function_with_query_config(client: ClientAPI) -> None:
1685+
client.reset()
1686+
coll = client.create_collection(
1687+
name="test_custom_embedding_function_with_query_config",
1688+
embedding_function=CustomEmbeddingFunctionWithQueryConfig(
1689+
task="document",
1690+
model_name="i_want_anything",
1691+
dim=3,
1692+
query_config={"task": "query"},
1693+
),
1694+
)
1695+
assert coll is not None
1696+
ef = coll.configuration.get("embedding_function")
1697+
assert ef is not None
1698+
assert ef.name() == "custom_ef_with_query_config"
1699+
assert ef.get_config() == {
1700+
"model_name": "i_want_anything",
1701+
"dim": 3,
1702+
"task": "document",
1703+
"query_config": {"task": "query"},
1704+
}
1705+
assert ef.default_space() == "cosine"
1706+
assert ef.supported_spaces() == ["cosine"]
1707+
assert np.array_equal(
1708+
ef.embed_query(input="How many people in Berlin?"),
1709+
np.array([[2.0, 2.0, 2.0]], dtype=np.float32),
1710+
)
1711+
1712+
1713+
def test_deserializing_custom_embedding_function_with_query_config_no_query_config(
1714+
client: ClientAPI,
1715+
) -> None:
1716+
json_string = """
1717+
{
1718+
"embedding_function": {
1719+
"type": "known",
1720+
"name": "custom_ef_with_query_config",
1721+
"config": {"model_name": "i_want_anything", "dim": 3, "task": "document"}
1722+
}
1723+
}
1724+
"""
1725+
config = load_collection_configuration_from_json(json.loads(json_string))
1726+
assert config is not None
1727+
assert config.get("embedding_function") is not None
1728+
ef = config.get("embedding_function")
1729+
assert ef is not None
1730+
assert ef.get_config() == {
1731+
"model_name": "i_want_anything",
1732+
"dim": 3,
1733+
"task": "document",
1734+
"query_config": None,
1735+
}
1736+
assert ef.default_space() == "cosine"
1737+
assert ef.supported_spaces() == ["cosine"]
1738+
assert np.array_equal(
1739+
ef.embed_query(input="How many people in Berlin?"),
1740+
np.array([[1.0, 1.0, 1.0]], dtype=np.float32),
1741+
)

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -237,6 +238,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
237238
"OllamaEmbeddingFunction",
238239
"InstructorEmbeddingFunction",
239240
"JinaEmbeddingFunction",
241+
"JinaQueryConfig",
240242
"MistralEmbeddingFunction",
241243
"MorphEmbeddingFunction",
242244
"VoyageAIEmbeddingFunction",

0 commit comments

Comments
 (0)