@@ -136,7 +136,6 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        # Instantiate the language model based on the model name (models that use the common interface)
         def handle_model(model_name, provider, token_key, default_token=8192):
             try:
                 self.model_token = models_tokens[provider][token_key]
@@ -153,84 +152,74 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "azure_openai", model_name)
 
-        if "gpt-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "openai", llm_params["model"])
-
-        if "fireworks" in llm_params["model"]:
+        elif "fireworks" in llm_params["model"]:
             model_name = "/".join(llm_params["model"].split("/")[1:])
             token_key = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "fireworks", token_key)
 
-        if "gemini" in llm_params["model"]:
+        elif "gemini" in llm_params["model"]:
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "google_genai", model_name)
 
-        if llm_params["model"].startswith("claude"):
+        elif llm_params["model"].startswith("claude"):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "anthropic", model_name)
 
-        if llm_params["model"].startswith("vertexai"):
+        elif llm_params["model"].startswith("vertexai"):
             return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
-        if "ollama" in llm_params["model"]:
+        elif "gpt-" in llm_params["model"]:
+            return handle_model(llm_params["model"], "openai", llm_params["model"])
+
+        elif "ollama" in llm_params["model"]:
             model_name = llm_params["model"].split("ollama/")[-1]
             token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
             return handle_model(model_name, "ollama", token_key)
 
-        if "hugging_face" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "hugging_face", model_name)
-
-        if "groq" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "groq", model_name)
-
-        if "bedrock" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "bedrock", model_name)
-
-        if "claude-3-" in llm_params["model"]:
+        elif "claude-3-" in llm_params["model"]:
             return handle_model(llm_params["model"], "anthropic", "claude3")
-
-        if llm_params["model"].startswith("mistral"):
+
+        elif llm_params["model"].startswith("mistral"):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "mistralai", model_name)
 
         # Instantiate the language model based on the model name (models that do not use the common interface)
-        if "deepseek" in llm_params["model"]:
+        elif "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
             except KeyError:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return DeepSeek(llm_params)
 
-        if "ernie" in llm_params["model"]:
+        elif "ernie" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["ernie"][llm_params["model"]]
             except KeyError:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return ErnieBotChat(llm_params)
-
-        if "oneapi" in llm_params["model"]:
+
+        elif "oneapi" in llm_params["model"]:
             # take the model after the last slash
             llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["oneapi"][llm_params["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return OneApi(llm_params)
-
-        if "nvidia" in llm_params["model"]:
+
+        elif "nvidia" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
                 llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return ChatNVIDIA(llm_params)
+        else:
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, llm_params["model"], model_name)
 
-        # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")
 
 
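For readers following the refactor: the per-provider `if` blocks for `gpt-`, `hugging_face`, `groq`, and `bedrock` are folded into a single `if`/`elif` chain, and the new `else` branch routes any remaining `provider/model` string through the shared `handle_model` helper, which presumably falls back to `default_token` when the `models_tokens` lookup fails. Below is a minimal, runnable sketch of that dispatch pattern; the `models_tokens` entries and the free function `create_llm` are illustrative stand-ins rather than the library's actual tables or API, and only the control flow mirrors the commit.

# Illustrative sketch only: models_tokens is a stand-in for the real table,
# and create_llm is a hypothetical free function mirroring the control flow
# of the refactored _create_llm.
models_tokens = {
    "openai": {"gpt-4o": 128000},
    "ollama": {"llama3": 8192},
}

def create_llm(llm_params: dict) -> dict:
    def handle_model(model_name, provider, token_key, default_token=8192):
        # Mirrors the diff's helper: look up the context window, falling
        # back to default_token when the provider/model pair is unknown.
        try:
            model_token = models_tokens[provider][token_key]
        except KeyError:
            print(f"Model not found, using default token size ({default_token})")
            model_token = default_token
        return {"provider": provider, "model": model_name, "tokens": model_token}

    model = llm_params["model"]
    if "gpt-" in model:
        return handle_model(model, "openai", model)
    elif "ollama" in model:
        name = model.split("ollama/")[-1]
        return handle_model(name, "ollama", name)
    else:
        # The new fallback: strip the prefix for the model name but pass the
        # full string as the provider key, as in the added else branch.
        name = model.split("/")[-1]
        return handle_model(name, model, name)

print(create_llm({"model": "ollama/llama3"}))        # known provider: 8192
print(create_llm({"model": "groq/llama3-8b-8192"}))  # unknown here: default

The trade-off in the fallback is visible in the second call: a provider missing from the token table no longer raises immediately but is given the default context size, with the full model string doing double duty as the provider key.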