Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions chromadb/utils/embedding_functions/cohere_embedding_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from chromadb.utils.embedding_functions.schemas import validate_config_schema
import base64
import io
import importlib
import warnings

_cohere_module = None

_pil_image_module = None


class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]):
def __init__(
Expand All @@ -23,16 +26,14 @@ def __init__(
model_name: str = "large",
api_key_env_var: str = "CHROMA_COHERE_API_KEY",
):
try:
import cohere
except ImportError:
# Use pre-imported module if available for faster access
cohere = _cohere_module
if cohere is None:
raise ValueError(
"The cohere python package is not installed. Please install it with `pip install cohere`"
)

try:
self._PILImage = importlib.import_module("PIL.Image")
except ImportError:
self._PILImage = _pil_image_module
if self._PILImage is None:
raise ValueError(
"The PIL python package is not installed. Please install it with `pip install pillow`"
)
Expand All @@ -43,13 +44,14 @@ def __init__(
"Please use environment variables via api_key_env_var for persistent storage.",
DeprecationWarning,
)

# Avoid repeated os.getenv by inlining it
self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
self.api_key = api_key if api_key is not None else os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")

self.model_name = model_name

self.client = cohere.Client(self.api_key)

def __call__(self, input: Embeddable) -> Embeddings:
Expand Down Expand Up @@ -142,11 +144,13 @@ def supported_spaces(self) -> List[Space]:

@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]":
# These are very fast, but micro-optimized by grouping for logic clarity.
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")
if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"

# CohereEmbeddingFunction constructor does not do any slow operations that can be optimized further at build time.
# The only measurable optimization is the module-level caching above, which speeds up instantiation.
return CohereEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name
)
Expand Down