@@ -46,7 +46,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4646 self .source = source
4747 self .config = config
4848 self .llm_model = self ._create_llm (config ["llm" ], chat = True )
49- self .embedder_model = self ._create_default_embedder (
49+ self .embedder_model = self ._create_default_embedder (llm_config = config [ "llm" ]
5050 ) if "embeddings" not in config else self ._create_embedder (
5151 config ["embeddings" ])
5252
@@ -91,6 +91,13 @@ def _set_model_token(self, llm):
9191 self .model_token = models_tokens ['mistral' ][llm .repo_id ]
9292 except KeyError :
9393 raise KeyError ("Model not supported" )
94+
95+ elif 'Google' in str (type (llm )):
96+ try :
97+ if 'gemini' in llm .model :
98+ self .model_token = models_tokens ['gemini' ][llm .model ]
99+ except KeyError :
100+ raise KeyError ("Model not supported" )
94101
95102 def _create_llm (self , llm_config : dict , chat = False ) -> object :
96103 """
@@ -197,7 +204,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
197204 raise ValueError (
198205 "Model provided by the configuration not supported" )
199206
200- def _create_default_embedder (self ) -> object :
207+ def _create_default_embedder (self , llm_config = None ) -> object :
201208 """
202209 Create an embedding model instance based on the chosen llm model.
203210
@@ -207,6 +214,8 @@ def _create_default_embedder(self) -> object:
207214 Raises:
208215 ValueError: If the model is not supported.
209216 """
217+ if isinstance (self .llm_model , Gemini ):
218+ return GoogleGenerativeAIEmbeddings (google_api_key = llm_config ['api_key' ], model = "models/embedding-001" )
210219 if isinstance (self .llm_model , OpenAI ):
211220 return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
212221 elif isinstance (self .llm_model , AzureOpenAIEmbeddings ):
@@ -241,7 +250,6 @@ def _create_embedder(self, embedder_config: dict) -> object:
241250 Raises:
242251 KeyError: If the model is not supported.
243252 """
244-
245253 if 'model_instance' in embedder_config :
246254 return embedder_config ['model_instance' ]
247255 # Instantiate the embedding model based on the model name
0 commit comments