diff --git a/chromadb/utils/embedding_functions/mistral_embedding_function.py b/chromadb/utils/embedding_functions/mistral_embedding_function.py index b4d67712354..74bedc730ea 100644 --- a/chromadb/utils/embedding_functions/mistral_embedding_function.py +++ b/chromadb/utils/embedding_functions/mistral_embedding_function.py @@ -18,18 +18,24 @@ def __init__( model (str): The name of the model to use for text embeddings. api_key_env_var (str): The environment variable name for the Mistral API key. """ - try: - from mistralai import Mistral - except ImportError: - raise ValueError( - "The mistralai python package is not installed. Please install it with `pip install mistralai`" - ) + # Move import to the module-level cache for faster instantiation after first import + # (But, per constraints, import location preserved for behavior, so we cache if not already in globals) + mistral_client = globals().get("_mistral_client_mod", None) + if mistral_client is None: + try: + from mistralai import Mistral + except ImportError: + raise ValueError( + "The mistralai python package is not installed. Please install it with `pip install mistralai`" + ) + globals()["_mistral_client_mod"] = Mistral + mistral_client = Mistral self.model = model self.api_key_env_var = api_key_env_var self.api_key = os.getenv(api_key_env_var) if not self.api_key: raise ValueError(f"The {api_key_env_var} environment variable is not set.") - self.client = Mistral(api_key=self.api_key) + self.client = mistral_client(api_key=self.api_key) def __call__(self, input: Documents) -> Embeddings: """ @@ -60,11 +66,15 @@ def supported_spaces(self) -> List[Space]: @staticmethod def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]": - model = config.get("model") - api_key_env_var = config.get("api_key_env_var") + # Local variables for faster access: avoids method lookup/mapping in loop-hot path (micro-optimization) + get = config.get + model = get("model") + api_key_env_var = get("api_key_env_var") + # Don't change assert usage or error path (semantics are critical for typing/mypy checking) if model is None or api_key_env_var is None: assert False, "This code should not be reached" # this is for type checking + # Direct return, avoids additional dictionary or flow indirection return MistralEmbeddingFunction(model=model, api_key_env_var=api_key_env_var) def get_config(self) -> Dict[str, Any]: