66"""
77
88import os
9- import requests
109from typing import cast
10+
11+ import requests
1112from chromadb import Documents , EmbeddingFunction , Embeddings
1213from chromadb .api .types import validate_embedding_function
1314
1415
1516class MistralEmbeddingFunction (EmbeddingFunction ):
1617 """
1718 Mistral AI embedding function compatible with ChromaDB.
18-
19+
1920 This class implements the ChromaDB EmbeddingFunction interface
2021 to provide seamless integration with Mistral AI's embedding models.
21-
22+
2223 Attributes:
2324 api_key (str): Mistral API key for authentication
2425 model_name (str): Name of the Mistral embedding model to use
2526 base_url (str): Base URL for Mistral API endpoints
2627 max_retries (int): Maximum number of request attempts per call (at least one attempt is always made)
2728 timeout (int): Request timeout in seconds
2829 """
29-
30+
3031 def __init__ (
3132 self ,
32- api_key : str = None ,
33+ api_key : str | None = None ,
3334 model_name : str = "mistral-embed" ,
3435 base_url : str = "https://api.mistral.ai/v1" ,
3536 max_retries : int = 3 ,
3637 timeout : int = 30 ,
3738 ):
3839 """
3940 Initialize Mistral embedding function.
40-
41+
4142 Args:
4243 api_key: Mistral API key (defaults to MISTRAL_API_KEY env var)
4344 model_name: Mistral embedding model name
4445 base_url: Mistral API base URL
4546 max_retries: Maximum number of attempts per API call (at least one attempt is always made)
4647 timeout: Request timeout in seconds
47-
48+
4849 Raises:
4950 ValueError: If no API key is provided, or if the embedding function fails ChromaDB validation
5051 """
@@ -53,52 +54,52 @@ def __init__(
5354 raise ValueError (
5455 "Mistral API key is required. Set MISTRAL_API_KEY environment variable."
5556 )
56-
57+
5758 self .model_name = model_name
5859 self .base_url = base_url .rstrip ('/' )
5960 self .max_retries = max_retries
6061 self .timeout = timeout
61-
62+
6263 # Validate the embedding function
6364 try :
6465 validate_embedding_function (self )
6566 except Exception as e :
66- raise ValueError (f"Invalid Mistral embedding function: { str ( e ) } " ) from e
67-
67+ raise ValueError (f"Invalid Mistral embedding function: { e !s } " ) from e
68+
6869 def __call__ (self , input : Documents ) -> Embeddings :
6970 """
7071 Generate embeddings for input documents.
71-
72+
7273 Args:
7374 input: Documents to embed (string or list of strings)
74-
75+
7576 Returns:
7677 List of embedding vectors
77-
78+
7879 Raises:
7980 RuntimeError: If API calls fail after max_retries attempts
8081 """
8182 if isinstance (input , str ):
8283 input = [input ]
83-
84+
8485 if not input :
8586 return []
86-
87+
8788 # Prepare the request
8889 headers = {
8990 "Authorization" : f"Bearer { self .api_key } " ,
9091 "Content-Type" : "application/json"
9192 }
92-
93+
9394 data = {
9495 "model" : self .model_name ,
9596 "input" : input
9697 }
97-
98+
9899 # Make API request with retry logic
99100 # Ensure at least one attempt is made, even if max_retries is 0
100101 attempts = max (1 , self .max_retries )
101-
102+
102103 for attempt in range (attempts ):
103104 try :
104105 response = requests .post (
@@ -108,26 +109,29 @@ def __call__(self, input: Documents) -> Embeddings:
108109 timeout = self .timeout
109110 )
110111 response .raise_for_status ()
111-
112+
112113 result = response .json ()
113114 embeddings = [item ["embedding" ] for item in result ["data" ]]
114-
115+
115116 return cast (Embeddings , embeddings )
116-
117+
117118 except requests .exceptions .RequestException as e :
118119 # If this is the last attempt, raise the error
119120 if attempt == attempts - 1 :
120121 raise RuntimeError (
121122 f"Failed to get embeddings from Mistral API after "
122- f"{ attempts } attempts: { str ( e ) } "
123+ f"{ attempts } attempts: { e !s } "
123124 ) from e
124125 # Otherwise, continue to next attempt
125126 continue
126-
127+
128+ # This should never be reached, but added for type safety
129+ raise RuntimeError ("Unexpected end of retry loop" )
130+
127131 def get_model_info (self ) -> dict :
128132 """
129133 Get information about the current model.
130-
134+
131135 Returns:
132136 Dictionary containing model information
133137 """
0 commit comments