@@ -146,78 +146,84 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 return init_chat_model(**llm_params)
-
-        if "fireworks" in llm_params["model"]:
-            model_name = "/".join(llm_params["model"].split("/")[1:])
-            token_key = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "fireworks", token_key)
-
-        elif "gemini" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "google_genai", model_name)
-
-        elif llm_params["model"].startswith("claude"):
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "anthropic", model_name)
-
-        elif llm_params["model"].startswith("vertexai"):
-            return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
-        elif "gpt-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "openai", llm_params["model"])
-
-        elif "ollama" in llm_params["model"]:
-            model_name = llm_params["model"].split("ollama/")[-1]
-            token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
-            return handle_model(model_name, "ollama", token_key)
-
-        elif "claude-3-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "anthropic", "claude3")
-
-        elif llm_params["model"].startswith("mistral"):
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "mistralai", model_name)
-
-        # Instantiate the language model based on the model name (models that do not use the common interface)
-        elif "deepseek" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["deepseek"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            return DeepSeek(llm_params)
-
-        elif "ernie" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["ernie"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            return ErnieBotChat(llm_params)
-
-        elif "oneapi" in llm_params["model"]:
-
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["oneapi"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OneApi(llm_params)
-
-        elif "nvidia" in llm_params["model"]:
-
-            try:
-                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return ChatNVIDIA(llm_params)
-        else:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, llm_params["model"], model_name)
-
-        raise ValueError("Model provided by the configuration not supported")
-
+
+        known_models = ["azure", "fireworks", "gemini", "claude", "vertexai", "hugging_face", "groq", "gpt-", "ollama", "claude-3-", "bedrock", "mistral", "ernie", "oneapi", "nvidia"]
+
+        if not any(known_model in llm_params["model"] for known_model in known_models):
+            raise ValueError(f"Model '{llm_params['model']}' is not supported")
+
+        try:
+            if "fireworks" in llm_params["model"]:
+                model_name = "/".join(llm_params["model"].split("/")[1:])
+                token_key = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "fireworks", token_key)
+
+            elif "gemini" in llm_params["model"]:
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "google_genai", model_name)
+
+            elif llm_params["model"].startswith("claude"):
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "anthropic", model_name)
+
+            elif llm_params["model"].startswith("vertexai"):
+                return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
+
+            elif "gpt-" in llm_params["model"]:
+                return handle_model(llm_params["model"], "openai", llm_params["model"])
+
+            elif "ollama" in llm_params["model"]:
+                model_name = llm_params["model"].split("ollama/")[-1]
+                token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
+                return handle_model(model_name, "ollama", token_key)
+
+            elif "claude-3-" in llm_params["model"]:
+                return handle_model(llm_params["model"], "anthropic", "claude3")
+
+            elif llm_params["model"].startswith("mistral"):
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "mistralai", model_name)
+
+            # Instantiate the language model based on the model name (models that do not use the common interface)
+            elif "deepseek" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["deepseek"][llm_params["model"]]
+                except KeyError:
+                    print("model not found, using default token size (8192)")
+                    self.model_token = 8192
+                return DeepSeek(llm_params)
+
+            elif "ernie" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["ernie"][llm_params["model"]]
+                except KeyError:
+                    print("model not found, using default token size (8192)")
+                    self.model_token = 8192
+                return ErnieBotChat(llm_params)
+
+            elif "oneapi" in llm_params["model"]:
+                # take the model name after the last slash
+                llm_params["model"] = llm_params["model"].split("/")[-1]
+                try:
+                    self.model_token = models_tokens["oneapi"][llm_params["model"]]
+                except KeyError as exc:
+                    raise KeyError("Model not supported") from exc
+                return OneApi(llm_params)
+
+            elif "nvidia" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
+                    llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+                except KeyError as exc:
+                    raise KeyError("Model not supported") from exc
+                return ChatNVIDIA(llm_params)
+
+            else:
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, llm_params["model"], model_name)
+
+        except KeyError as e:
+            print(f"Model not supported: {e}")

     def get_state(self, key=None) -> dict:
         """
@@ -264,4 +270,4 @@ def _create_graph(self):
     def run(self) -> str:
         """
         Abstract method to execute the graph and return the result.
-        """
+        """
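
The supported-model gate in the first hunk has to tolerate provider prefixes: configured model strings look like "ollama/llama3" or "gpt-4o", so an exact membership test against known_models would reject every real value. The check is therefore a substring match, in line with the `in` tests the dispatch branches below it already use. A minimal standalone sketch of that gate (the is_supported helper and the assertions are illustrative, not part of the change):

known_models = ["azure", "fireworks", "gemini", "claude", "vertexai", "hugging_face",
                "groq", "gpt-", "ollama", "claude-3-", "bedrock", "mistral",
                "ernie", "oneapi", "nvidia"]

def is_supported(model: str) -> bool:
    # Substring match mirrors the dispatch branches ('"fireworks" in model', etc.),
    # so provider-prefixed names pass the gate.
    return any(known in model for known in known_models)

assert is_supported("ollama/llama3")
assert is_supported("gpt-4o")
assert not is_supported("llama3")  # bare name with no recognized provider token

A name counts as supported only if some known provider token appears in it; anything else is rejected before the try/except dispatch runs.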
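For reference, every common-interface branch funnels into handle_model, whose body lies mostly outside this hunk: the hunk header gives its signature and the leading context lines show its tail. A plausible reconstruction under those constraints, assuming a models_tokens lookup like the one visible in the deepseek/ernie branches (everything before the with block is inferred, not shown in this diff):

def handle_model(model_name, provider, token_key, default_token=8192):
    # Inferred body: resolve the model's context window from models_tokens,
    # falling back to default_token, then build the chat model via LangChain.
    try:
        self.model_token = models_tokens[provider][token_key]
    except KeyError:
        print(f"Model not found, using default token size ({default_token})")
        self.model_token = default_token
    llm_params["model_provider"] = provider
    llm_params["model"] = model_name
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence construction-time warnings
        return init_chat_model(**llm_params)

Because handle_model is nested inside the enclosing method, self, llm_params, models_tokens, and warnings all come from the surrounding scope; init_chat_model is LangChain's provider-agnostic factory (from langchain.chat_models import init_chat_model).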