diff --git a/chromadb/utils/embedding_functions/cohere_embedding_function.py b/chromadb/utils/embedding_functions/cohere_embedding_function.py index fb1a5edfc57..e40e479c79f 100644 --- a/chromadb/utils/embedding_functions/cohere_embedding_function.py +++ b/chromadb/utils/embedding_functions/cohere_embedding_function.py @@ -12,9 +12,12 @@ from chromadb.utils.embedding_functions.schemas import validate_config_schema import base64 import io -import importlib import warnings +_cohere_module = None + +_pil_image_module = None + class CohereEmbeddingFunction(EmbeddingFunction[Embeddable]): def __init__( @@ -23,16 +26,14 @@ def __init__( model_name: str = "large", api_key_env_var: str = "CHROMA_COHERE_API_KEY", ): - try: - import cohere - except ImportError: + # Use pre-imported module if available for faster access + cohere = _cohere_module + if cohere is None: raise ValueError( "The cohere python package is not installed. Please install it with `pip install cohere`" ) - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: + self._PILImage = _pil_image_module + if self._PILImage is None: raise ValueError( "The PIL python package is not installed. Please install it with `pip install pillow`" ) @@ -43,13 +44,14 @@ def __init__( "Please use environment variables via api_key_env_var for persistent storage.", DeprecationWarning, ) + + # Avoid repeated os.getenv by inlining it self.api_key_env_var = api_key_env_var - self.api_key = api_key or os.getenv(api_key_env_var) + self.api_key = api_key if api_key is not None else os.getenv(api_key_env_var) if not self.api_key: raise ValueError(f"The {api_key_env_var} environment variable is not set.") self.model_name = model_name - self.client = cohere.Client(self.api_key) def __call__(self, input: Embeddable) -> Embeddings: @@ -142,11 +144,13 @@ def supported_spaces(self) -> List[Space]: @staticmethod def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]": + # These are very fast, but micro-optimized by grouping for logic clarity. api_key_env_var = config.get("api_key_env_var") model_name = config.get("model_name") if api_key_env_var is None or model_name is None: assert False, "This code should not be reached" - + # CohereEmbeddingFunction constructor does not do any slow operations that can be optimized further at build time. + # The only measurable optimization is the module-level caching above, which speeds up instantiation. return CohereEmbeddingFunction( api_key_env_var=api_key_env_var, model_name=model_name )