Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chromadb/test/ef/test_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_get_builtins_holds() -> None:
"ChromaCloudQwenEmbeddingFunction",
"ChromaCloudSpladeEmbeddingFunction",
"ChromaBm25EmbeddingFunction",
"PylateColBERTEmbeddingFunction",
}

assert expected_builtins == embedding_functions.get_builtins()
Expand Down
6 changes: 6 additions & 0 deletions chromadb/utils/embedding_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@
ChromaBm25EmbeddingFunction,
)

from chromadb.utils.embedding_functions.pylate_colbert_embedding_function import (
PylateColBERTEmbeddingFunction,
)

# Get all the class names for backward compatibility
_all_classes: Set[str] = {
Expand Down Expand Up @@ -127,6 +130,7 @@
"ChromaCloudQwenEmbeddingFunction",
"ChromaCloudSpladeEmbeddingFunction",
"ChromaBm25EmbeddingFunction",
"PylateColBERTEmbeddingFunction",
}


Expand Down Expand Up @@ -163,6 +167,7 @@ def get_builtins() -> Set[str]:
"cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
"together_ai": TogetherAIEmbeddingFunction,
"chroma-cloud-qwen": ChromaCloudQwenEmbeddingFunction,
"pylate_colbert": PylateColBERTEmbeddingFunction,
}

sparse_known_embedding_functions: Dict[str, Type[SparseEmbeddingFunction]] = { # type: ignore
Expand Down Expand Up @@ -291,6 +296,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
"ChromaCloudQwenEmbeddingFunction",
"ChromaCloudSpladeEmbeddingFunction",
"ChromaBm25EmbeddingFunction",
"PylateColBERTEmbeddingFunction",
"register_embedding_function",
"config_to_embedding_function",
"known_embedding_functions",
Expand Down
37 changes: 37 additions & 0 deletions chromadb/utils/embedding_functions/jina_embedding_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import numpy as np
import warnings
from chromadb.utils.muvera import create_fdes
import importlib
import base64
import io
Expand All @@ -37,6 +38,7 @@ def __init__(
dimensions: Optional[int] = None,
embedding_type: Optional[str] = None,
normalized: Optional[bool] = None,
return_multivector: Optional[bool] = None,
query_config: Optional[JinaQueryConfig] = None,
):
"""
Expand Down Expand Up @@ -101,6 +103,7 @@ def __init__(
self.dimensions = dimensions
self.embedding_type = embedding_type
self.normalized = normalized
self.return_multivector = return_multivector
self.query_config = query_config

self._api_url = "https://api.jina.ai/v1/embeddings"
Expand Down Expand Up @@ -149,6 +152,8 @@ def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
payload["embedding_type"] = self.embedding_type
if self.normalized is not None:
payload["normalized"] = self.normalized
if self.return_multivector is not None:
payload["return_multivector"] = self.return_multivector

# overwrite parameteres when query payload is used
if is_query and self.query_config is not None:
Expand All @@ -170,6 +175,35 @@ def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
if "data" not in resp:
raise RuntimeError(resp.get("detail", "Unknown error"))

if self.return_multivector:
# if it gives back multivector embeddings
multi_embeddings_data: List[Dict[str, Any]] = resp["data"]
sorted_multi_embeddings = sorted(
multi_embeddings_data, key=lambda e: e["index"]
)
multi_embeddings: List[Embeddings] = [
[
np.array(vec, dtype=np.float32)
for vec in multi_embedding_obj["embeddings"]
]
for multi_embedding_obj in sorted_multi_embeddings
]

if not multi_embeddings or not multi_embeddings[0]:
raise RuntimeError(
"Invalid multivector embeddings format from Jina API"
)

dims = len(multi_embeddings[0][0])
fdes = create_fdes(
multi_embeddings,
dims=dims,
is_query=is_query,
fill_empty_partitions=not is_query,
)

return fdes

embeddings_data: List[Dict[str, Union[int, List[float]]]] = resp["data"]

# Sort resulting embeddings by index
Expand Down Expand Up @@ -231,6 +265,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
dimensions = config.get("dimensions")
embedding_type = config.get("embedding_type")
normalized = config.get("normalized")
return_multivector = config.get("return_multivector")
query_config = config.get("query_config")

if api_key_env_var is None or model_name is None:
Expand All @@ -245,6 +280,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
dimensions=dimensions,
embedding_type=embedding_type,
normalized=normalized,
return_multivector=return_multivector,
query_config=query_config,
)

Expand All @@ -258,6 +294,7 @@ def get_config(self) -> Dict[str, Any]:
"dimensions": self.dimensions,
"embedding_type": self.embedding_type,
"normalized": self.normalized,
"return_multivector": self.return_multivector,
"query_config": self.query_config,
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
from typing import List, Dict, Any
from chromadb.utils.embedding_functions.schemas import validate_config_schema
from chromadb.utils.muvera import create_fdes


class PylateColBERTEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the ColBERT API.
"""

def __init__(
self,
model_name: str,
):
"""
Initialize the PylateColBERTEmbeddingFunction.

Args:
model_name (str): The name of the model to use for text embeddings.
Examples: "mixedbread-ai/mxbai-edge-colbert-v0-17m", "mixedbread-ai/mxbai-edge-colbert-v0-32m", "lightonai/colbertv2.0", "answerdotai/answerai-colbert-small-v1", "jinaai/jina-colbert-v2", "GTE-ModernColBERT-v1"
"""
try:
from pylate import models
except ImportError:
raise ValueError(
"The pylate colbert python package is not installed. Please install it with `pip install pylate-colbert`"
)

self.model_name = model_name
self.model = models.ColBERT(model_name_or_path=model_name)

def __call__(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.

Args:
input (Documents): A list of texts to get embeddings for.

Returns:
Embeddings: The embeddings for the texts.
"""
multivec = self.model.encode(input, batch_size=32, is_query=False)
if not multivec or not multivec[0]:
raise ValueError("Model returned empty multivector embeddings")
return create_fdes(
multivec,
dims=len(multivec[0][0]),
is_query=False,
fill_empty_partitions=True,
)

def embed_query(self, input: Documents) -> Embeddings:
"""
Get the embeddings for a list of texts.

Args:
input (Documents): A list of texts to get embeddings for.

Returns:
Embeddings: The embeddings for the texts.
"""
multivec = self.model.encode(input, batch_size=32, is_query=True)
if not multivec or not multivec[0]:
raise ValueError("Model returned empty multivector embeddings")
return create_fdes(
multivec,
dims=len(multivec[0][0]),
is_query=True,
fill_empty_partitions=False,
)

@staticmethod
def name() -> str:
return "pylate_colbert"

def default_space(self) -> Space:
return "ip" # muvera uses dot product to approximate multivec similarity

def supported_spaces(self) -> List[Space]:
return [
"ip"
] # no cosine bc muvera does not normalize the fde, no l2 bc muvera uses dot product

@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
model_name = config.get("model_name")

if model_name is None:
assert False, "This code should not be reached"

return PylateColBERTEmbeddingFunction(model_name=model_name)

def get_config(self) -> Dict[str, Any]:
return {"model_name": self.model_name}

def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)

@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.

Args:
config: Configuration to validate

Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "pylate_colbert")
Loading
Loading