Skip to content

Commit e67eb4f

Browse files
committed
[ENH] add query config on collection configuration
1 parent 533f190 commit e67eb4f

File tree

7 files changed

+462
-62
lines changed

7 files changed

+462
-62
lines changed

chromadb/api/collection_configuration.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def load_collection_configuration_from_json(
9999
raise ValueError(
100100
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
101101
)
102-
103102
else:
104103
ef = None
105104

@@ -148,11 +147,6 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
148147
if ef is None:
149148
ef = None
150149
ef_config = {"type": "legacy"}
151-
return {
152-
"hnsw": hnsw_config,
153-
"spann": spann_config,
154-
"embedding_function": ef_config,
155-
}
156150

157151
if ef is not None:
158152
try:
@@ -260,16 +254,6 @@ class CreateCollectionConfiguration(TypedDict, total=False):
260254
embedding_function: Optional[EmbeddingFunction] # type: ignore
261255

262256

263-
def load_collection_configuration_from_create_collection_configuration(
264-
config: CreateCollectionConfiguration,
265-
) -> CollectionConfiguration:
266-
return CollectionConfiguration(
267-
hnsw=config.get("hnsw"),
268-
spann=config.get("spann"),
269-
embedding_function=config.get("embedding_function"),
270-
)
271-
272-
273257
def create_collection_configuration_from_legacy_collection_metadata(
274258
metadata: CollectionMetadata,
275259
) -> CreateCollectionConfiguration:
@@ -301,13 +285,6 @@ def create_collection_configuration_from_legacy_metadata_dict(
301285
return CreateCollectionConfiguration(hnsw=hnsw_config)
302286

303287

304-
def load_create_collection_configuration_from_json_str(
305-
json_str: str,
306-
) -> CreateCollectionConfiguration:
307-
json_map = json.loads(json_str)
308-
return load_create_collection_configuration_from_json(json_map)
309-
310-
311288
# TODO: make warnings prettier and add link to migration docs
312289
def load_create_collection_configuration_from_json(
313290
json_map: Dict[str, Any]

chromadb/api/models/CollectionCommon.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def _validate_and_prepare_query_request(
313313
# Prepare
314314
if query_records["embeddings"] is None:
315315
validate_record_set_for_embedding(record_set=query_records)
316-
request_embeddings = self._embed_record_set(record_set=query_records)
316+
request_embeddings = self._embed_record_set(
317+
record_set=query_records, is_query=True
318+
)
317319
else:
318320
request_embeddings = query_records["embeddings"]
319321

@@ -531,7 +533,10 @@ def _update_model_after_modify_success(
531533
)
532534

533535
def _embed_record_set(
534-
self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None
536+
self,
537+
record_set: BaseRecordSet,
538+
embeddable_fields: Optional[Set[str]] = None,
539+
is_query: bool = False,
535540
) -> Embeddings:
536541
if embeddable_fields is None:
537542
embeddable_fields = get_default_embeddable_record_set_fields()
@@ -545,27 +550,41 @@ def _embed_record_set(
545550
"You must set a data loader on the collection if loading from URIs."
546551
)
547552
return self._embed(
548-
input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required]
553+
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
554+
is_query=is_query,
549555
)
550556
else:
551-
return self._embed(input=record_set[field]) # type: ignore[literal-required]
557+
return self._embed(
558+
input=record_set[field], # type: ignore[literal-required]
559+
is_query=is_query,
560+
)
552561
raise ValueError(
553562
"Record does not contain any non-None fields that can be embedded."
554563
f"Embeddable Fields: {embeddable_fields}"
555564
f"Record Fields: {record_set}"
556565
)
557566

558-
def _embed(self, input: Any) -> Embeddings:
567+
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
559568
if self._embedding_function is not None and not isinstance(
560569
self._embedding_function, ef.DefaultEmbeddingFunction
561570
):
562-
return self._embedding_function(input=input)
571+
if is_query:
572+
return self._embedding_function.embed_query(input=input)
573+
else:
574+
return self._embedding_function(input=input)
575+
563576
config_ef = self.configuration.get("embedding_function")
564577
if config_ef is not None:
565-
return config_ef(input=input)
578+
if is_query:
579+
return config_ef.embed_query(input=input)
580+
else:
581+
return config_ef(input=input)
566582
if self._embedding_function is None:
567583
raise ValueError(
568584
"You must provide an embedding function to compute embeddings."
569585
"https://docs.trychroma.com/guides/embeddings"
570586
)
571-
return self._embedding_function(input=input)
587+
if is_query:
588+
return self._embedding_function.embed_query(input=input)
589+
else:
590+
return self._embedding_function(input=input)

chromadb/api/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,13 @@ class EmbeddingFunction(Protocol[D]):
545545
def __call__(self, input: D) -> Embeddings:
546546
...
547547

548+
def embed_query(self, input: D) -> Embeddings:
549+
"""
550+
Get the embeddings for a query input.
551+
This method is optional, and if not implemented, the default behavior is to call __call__.
552+
"""
553+
return self.__call__(input)
554+
548555
def __init_subclass__(cls) -> None:
549556
super().__init_subclass__()
550557
# Raise an exception if __call__ is not defined since it is expected to be defined

chromadb/utils/embedding_functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from chromadb.utils.embedding_functions.jina_embedding_function import (
3434
JinaEmbeddingFunction,
35+
JinaQueryConfig,
3536
)
3637
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
3738
VoyageAIEmbeddingFunction,
@@ -232,6 +233,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
232233
"OllamaEmbeddingFunction",
233234
"InstructorEmbeddingFunction",
234235
"JinaEmbeddingFunction",
236+
"JinaQueryConfig",
235237
"MistralEmbeddingFunction",
236238
"VoyageAIEmbeddingFunction",
237239
"ONNXMiniLM_L6_V2",

0 commit comments

Comments
 (0)