@@ -146,138 +146,61 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        # Instantiate the language model based on the model name
-        if "gpt-" in llm_params["model"]:
+        # Instantiate the language model based on the model name (models that use the common interface)
+        def handle_model(model_name, provider, token_key, default_token=8192):
             try:
-                self.model_token = models_tokens["openai"][llm_params["model"]]
-                llm_params["model_provider"] = "openai"
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+                self.model_token = models_tokens[provider][token_key]
+            except KeyError:
+                print(f"Model not found, using default token size ({default_token})")
+                self.model_token = default_token
+            llm_params["model_provider"] = provider
+            llm_params["model"] = model_name
             return init_chat_model(**llm_params)
 
-        if "oneapi" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["oneapi"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OneApi(llm_params)
+        if "gpt-" in llm_params["model"]:
+            return handle_model(llm_params["model"], "openai", llm_params["model"])
 
         if "fireworks" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "fireworks"
-            return init_chat_model(**llm_params)
+            model_name = "/".join(llm_params["model"].split("/")[1:])
+            token_key = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "fireworks", token_key)
 
         if "azure" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "azure_openai"
-            return init_chat_model(**llm_params)
-
-        if "nvidia" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return ChatNVIDIA(llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "azure_openai", model_name)
 
         if "gemini" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_genai"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "google_genai", model_name)
 
         if llm_params["model"].startswith("claude"):
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["claude"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "anthropic", model_name)
 
         if llm_params["model"].startswith("vertexai"):
-            try:
-                self.model_token = models_tokens["vertexai"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_vertexai"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
         if "ollama" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("ollama/")[-1]
-            llm_params["model_provider"] = "ollama"
-
-            # allow user to set model_tokens in config
-            try:
-                if "model_tokens" in llm_params:
-                    self.model_token = llm_params["model_tokens"]
-                elif llm_params["model"] in models_tokens["ollama"]:
-                    try:
-                        self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError as exc:
-                        print("model not found, using default token size (8192)")
-                        self.model_token = 8192
-                else:
-                    self.model_token = 8192
-            except AttributeError:
-                self.model_token = 8192
-
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("ollama/")[-1]
+            token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
+            return handle_model(model_name, "ollama", token_key)
 
         if "hugging_face" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "hugging_face"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "hugging_face", model_name)
 
         if "groq" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-
-            try:
-                self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "groq"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "groq", model_name)
 
         if "bedrock" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "bedrock"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "bedrock", model_name)
 
         if "claude-3-" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["claude"]["claude3"]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "anthropic", "claude3")
 
+        # Instantiate the language model based on the model name (models that do not use the common interface)
         if "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
@@ -293,7 +216,25 @@ def _create_llm(self, llm_config: dict) -> object:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return ErnieBotChat(llm_params)
+
+        if "oneapi" in llm_params["model"]:
+            # take the model after the last dash
+            llm_params["model"] = llm_params["model"].split("/")[-1]
+            try:
+                self.model_token = models_tokens["oneapi"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return OneApi(llm_params)
+
+        if "nvidia" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
+                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return ChatNVIDIA(llm_params)
 
+        # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")
 
     def _create_default_embedder(self, llm_config=None) -> object:
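
For readers skimming the diff, here is a minimal, standalone sketch of the dispatch pattern this commit converges on: each provider branch now funnels into a single `handle_model` helper that looks up the context window in a token table, falls back to a default size, and normalizes the provider/model fields. The token table, config values, and the plain-dict return below are illustrative stand-ins, not the library's real data; the actual method stores the token count on `self.model_token` and passes the normalized parameters to `init_chat_model`.

```python
# Illustrative sketch only: stub token table and a dict return in place of
# the real models_tokens data and the init_chat_model() call.
models_tokens = {"openai": {"gpt-4o": 128000}, "ollama": {"llama3": 8192}}

def create_llm(llm_params: dict) -> dict:
    def handle_model(model_name, provider, token_key, default_token=8192):
        try:
            model_token = models_tokens[provider][token_key]
        except KeyError:
            print(f"Model not found, using default token size ({default_token})")
            model_token = default_token
        # The real code sets self.model_token and returns init_chat_model(**llm_params).
        return {**llm_params, "model": model_name,
                "model_provider": provider, "model_tokens": model_token}

    model = llm_params["model"]
    if "gpt-" in model:
        return handle_model(model, "openai", model)
    if "ollama" in model:
        name = model.split("ollama/")[-1]
        return handle_model(name, "ollama", name)
    raise ValueError("Model provided by the configuration not supported")

print(create_llm({"model": "ollama/llama3"}))
# -> {'model': 'llama3', 'model_provider': 'ollama', 'model_tokens': 8192}
```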