@@ -146,90 +146,84 @@ def handle_model(model_name, provider, token_key, default_token=8192):
146146 with warnings .catch_warnings ():
147147 warnings .simplefilter ("ignore" )
148148 return init_chat_model (** llm_params )
149-
150- if "azure" in llm_params ["model" ]:
151- model_name = llm_params ["model" ].split ("/" )[- 1 ]
152- return handle_model (model_name , "azure_openai" , model_name )
153-
154- if "gpt-" in llm_params ["model" ]:
155- return handle_model (llm_params ["model" ], "openai" , llm_params ["model" ])
156-
157- if "fireworks" in llm_params ["model" ]:
158- model_name = "/" .join (llm_params ["model" ].split ("/" )[1 :])
159- token_key = llm_params ["model" ].split ("/" )[- 1 ]
160- return handle_model (model_name , "fireworks" , token_key )
161-
162- if "gemini" in llm_params ["model" ]:
163- model_name = llm_params ["model" ].split ("/" )[- 1 ]
164- return handle_model (model_name , "google_genai" , model_name )
165-
166- if llm_params ["model" ].startswith ("claude" ):
167- model_name = llm_params ["model" ].split ("/" )[- 1 ]
168- return handle_model (model_name , "anthropic" , model_name )
169-
170- if llm_params ["model" ].startswith ("vertexai" ):
171- return handle_model (llm_params ["model" ], "google_vertexai" , llm_params ["model" ])
172149
173- if "ollama" in llm_params ["model" ]:
174- model_name = llm_params ["model" ].split ("ollama/" )[- 1 ]
175- token_key = model_name if "model_tokens" not in llm_params else llm_params ["model_tokens" ]
176- return handle_model (model_name , "ollama" , token_key )
177-
178- if "hugging_face" in llm_params ["model" ]:
179- model_name = llm_params ["model" ].split ("/" )[- 1 ]
180- return handle_model (model_name , "hugging_face" , model_name )
181-
182- if "groq" in llm_params ["model" ]:
183- model_name = llm_params ["model" ].split ("/" )[- 1 ]
184- return handle_model (model_name , "groq" , model_name )
185-
186- if "bedrock" in llm_params ["model" ]:
187- model_name = llm_params ["model" ].split ("/" )[- 1 ]
188- return handle_model (model_name , "bedrock" , model_name )
189-
190- if "claude-3-" in llm_params ["model" ]:
191- return handle_model (llm_params ["model" ], "anthropic" , "claude3" )
192-
193- if llm_params ["model" ].startswith ("mistral" ):
194- model_name = llm_params ["model" ].split ("/" )[- 1 ]
195- return handle_model (model_name , "mistralai" , model_name )
196-
197- # Instantiate the language model based on the model name (models that do not use the common interface)
198- if "deepseek" in llm_params ["model" ]:
199- try :
200- self .model_token = models_tokens ["deepseek" ][llm_params ["model" ]]
201- except KeyError :
202- print ("model not found, using default token size (8192)" )
203- self .model_token = 8192
204- return DeepSeek (llm_params )
205-
206- if "ernie" in llm_params ["model" ]:
207- try :
208- self .model_token = models_tokens ["ernie" ][llm_params ["model" ]]
209- except KeyError :
210- print ("model not found, using default token size (8192)" )
211- self .model_token = 8192
212- return ErnieBotChat (** llm_params )
213-
214- if "oneapi" in llm_params ["model" ]:
215- # take the model after the last dash
216- llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
217- try :
218- self .model_token = models_tokens ["oneapi" ][llm_params ["model" ]]
219- except KeyError as exc :
220- raise KeyError ("Model not supported" ) from exc
221- return OneApi (llm_params )
222-
223- if "nvidia" in llm_params ["model" ]:
224- try :
225- self .model_token = models_tokens ["nvidia" ][llm_params ["model" ].split ("/" )[- 1 ]]
226- llm_params ["model" ] = "/" .join (llm_params ["model" ].split ("/" )[1 :])
227- except KeyError as exc :
228- raise KeyError ("Model not supported" ) from exc
229- return ChatNVIDIA (** llm_params )
230-
231- # Raise an error if the model did not match any of the previous cases
232- raise ValueError ("Model provided by the configuration not supported" )
150+ known_models = ["openai" , "azure_openai" , "google_genai" , "ollama" , "oneapi" , "nvidia" , "groq" , "google_vertexai" , "bedrock" , "mistralai" , "hugging_face" , "deepseek" , "ernie" , "fireworks" ]
151+
152+ if llm_params ["model" ] not in known_models :
153+ raise ValueError (f"Model '{ llm_params ['model' ]} ' is not supported" )
154+
155+ try :
156+ if "fireworks" in llm_params ["model" ]:
157+ model_name = "/" .join (llm_params ["model" ].split ("/" )[1 :])
158+ token_key = llm_params ["model" ].split ("/" )[- 1 ]
159+ return handle_model (model_name , "fireworks" , token_key )
160+
161+ elif "gemini" in llm_params ["model" ]:
162+ model_name = llm_params ["model" ].split ("/" )[- 1 ]
163+ return handle_model (model_name , "google_genai" , model_name )
164+
165+ elif llm_params ["model" ].startswith ("claude" ):
166+ model_name = llm_params ["model" ].split ("/" )[- 1 ]
167+ return handle_model (model_name , "anthropic" , model_name )
168+
169+ elif llm_params ["model" ].startswith ("vertexai" ):
170+ return handle_model (llm_params ["model" ], "google_vertexai" , llm_params ["model" ])
171+
172+ elif "gpt-" in llm_params ["model" ]:
173+ return handle_model (llm_params ["model" ], "openai" , llm_params ["model" ])
174+
175+ elif "ollama" in llm_params ["model" ]:
176+ model_name = llm_params ["model" ].split ("ollama/" )[- 1 ]
177+ token_key = model_name if "model_tokens" not in llm_params else llm_params ["model_tokens" ]
178+ return handle_model (model_name , "ollama" , token_key )
179+
180+ elif "claude-3-" in llm_params ["model" ]:
181+ return handle_model (llm_params ["model" ], "anthropic" , "claude3" )
182+
183+ elif llm_params ["model" ].startswith ("mistral" ):
184+ model_name = llm_params ["model" ].split ("/" )[- 1 ]
185+ return handle_model (model_name , "mistralai" , model_name )
186+
187+ # Instantiate the language model based on the model name (models that do not use the common interface)
188+ elif "deepseek" in llm_params ["model" ]:
189+ try :
190+ self .model_token = models_tokens ["deepseek" ][llm_params ["model" ]]
191+ except KeyError :
192+ print ("model not found, using default token size (8192)" )
193+ self .model_token = 8192
194+ return DeepSeek (llm_params )
195+
196+ elif "ernie" in llm_params ["model" ]:
197+ try :
198+ self .model_token = models_tokens ["ernie" ][llm_params ["model" ]]
199+ except KeyError :
200+ print ("model not found, using default token size (8192)" )
201+ self .model_token = 8192
202+ return ErnieBotChat (llm_params )
203+
204+ elif "oneapi" in llm_params ["model" ]:
205+ # take the model after the last dash
206+ llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
207+ try :
208+ self .model_token = models_tokens ["oneapi" ][llm_params ["model" ]]
209+ except KeyError :
210+ raise KeyError ("Model not supported" )
211+ return OneApi (llm_params )
212+
213+ elif "nvidia" in llm_params ["model" ]:
214+ try :
215+ self .model_token = models_tokens ["nvidia" ][llm_params ["model" ].split ("/" )[- 1 ]]
216+ llm_params ["model" ] = "/" .join (llm_params ["model" ].split ("/" )[1 :])
217+ except KeyError :
218+ raise KeyError ("Model not supported" )
219+ return ChatNVIDIA (llm_params )
220+
221+ else :
222+ model_name = llm_params ["model" ].split ("/" )[- 1 ]
223+ return handle_model (model_name , llm_params ["model" ], model_name )
224+
225+ except KeyError as e :
226+ print (f"Model not supported: { e } " )
233227
234228
235229 def get_state (self , key = None ) -> dict :
@@ -277,4 +271,4 @@ def _create_graph(self):
277271 def run (self ) -> str :
278272 """
279273 Abstract method to execute the graph and return the result.
280- """
274+ """
0 commit comments