@@ -125,103 +125,47 @@ def _create_llm(self, llm_config: dict) -> object:
125125 self .model_token = llm_params ["model_tokens" ]
126126 except KeyError as exc :
127127 raise KeyError ("model_tokens not specified" ) from exc
128- return llm_params ["model_instance" ]
129-
130- def handle_model (model_name , provider , token_key , default_token = 8192 ):
131- try :
132- self .model_token = models_tokens [provider ][token_key ]
133- except KeyError :
134- print (f"Model not found, using default token size ({ default_token } )" )
135- self .model_token = default_token
136- llm_params ["model_provider" ] = provider
137- llm_params ["model" ] = model_name
138- with warnings .catch_warnings ():
139- warnings .simplefilter ("ignore" )
140- return init_chat_model (** llm_params )
141-
142- known_models = {"chatgpt" ,"gpt" ,"openai" , "azure_openai" , "google_genai" ,
143- "ollama" , "oneapi" , "nvidia" , "groq" , "google_vertexai" ,
144- "bedrock" , "mistralai" , "hugging_face" , "deepseek" , "ernie" ,
145- "fireworks" , "claude-3-" }
146-
147- if llm_params ["model" ].split ("/" )[0 ] not in known_models and llm_params ["model" ].split ("-" )[0 ] not in known_models :
148- raise ValueError (f"Model '{ llm_params ['model' ]} ' is not supported" )
149-
128+ return llm_params ["model_instance" ]
129+
130+ known_providers = {"openai" , "azure_openai" , "google_genai" , "google_vertexai" ,
131+ "ollama" , "oneapi" , "nvidia" , "groq" , "anthropic" "bedrock" , "mistralai" ,
132+ "hugging_face" , "deepseek" , "ernie" , "fireworks" }
133+
134+ split_model_provider = llm_params ["model" ].split ("/" )
135+ llm_params ["model_provider" ] = split_model_provider [0 ]
136+ llm_params ["model" ] = split_model_provider [1 :]
137+
138+ if llm_params ["model_provider" ] not in known_providers :
139+ raise ValueError (f"Provider { llm_params ['model_provider' ]} is not supported. If possible, try to use a model instance instead." )
140+
150141 try :
151- if "fireworks" in llm_params ["model" ]:
152- model_name = "/" .join (llm_params ["model" ].split ("/" )[1 :])
153- token_key = llm_params ["model" ].split ("/" )[- 1 ]
154- return handle_model (model_name , "fireworks" , token_key )
155-
156- elif "gemini" in llm_params ["model" ]:
157- model_name = llm_params ["model" ].split ("/" )[- 1 ]
158- return handle_model (model_name , "google_genai" , model_name )
159-
160- elif llm_params ["model" ].startswith ("claude" ):
161- model_name = llm_params ["model" ].split ("/" )[- 1 ]
162- return handle_model (model_name , "anthropic" , model_name )
163-
164- elif llm_params ["model" ].startswith ("vertexai" ):
165- return handle_model (llm_params ["model" ], "google_vertexai" , llm_params ["model" ])
166-
167- elif "gpt-" in llm_params ["model" ]:
168- return handle_model (llm_params ["model" ], "openai" , llm_params ["model" ])
169-
170- elif "ollama" in llm_params ["model" ]:
171- model_name = llm_params ["model" ].split ("ollama/" )[- 1 ]
172- token_key = model_name if "model_tokens" not in llm_params else None
173- model_tokens = 8192 if "model_tokens" not in llm_params else llm_params ["model_tokens" ]
174- return handle_model (model_name , "ollama" , token_key , model_tokens )
175-
176- elif "claude-3-" in llm_params ["model" ]:
177- return handle_model (llm_params ["model" ], "anthropic" , "claude3" )
178-
179- elif llm_params ["model" ].startswith ("mistral" ):
180- model_name = llm_params ["model" ].split ("/" )[- 1 ]
181- return handle_model (model_name , "mistralai" , model_name )
182-
183- elif "deepseek" in llm_params ["model" ]:
184- try :
185- self .model_token = models_tokens ["deepseek" ][llm_params ["model" ]]
186- except KeyError :
187- print ("model not found, using default token size (8192)" )
188- self .model_token = 8192
189- return DeepSeek (llm_params )
190-
191- elif "ernie" in llm_params ["model" ]:
192- from langchain_community .chat_models import ErnieBotChat
193-
194- try :
195- self .model_token = models_tokens ["ernie" ][llm_params ["model" ]]
196- except KeyError :
197- print ("model not found, using default token size (8192)" )
198- self .model_token = 8192
199- return ErnieBotChat (llm_params )
200-
201- elif "oneapi" in llm_params ["model" ]:
202- llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
203- try :
204- self .model_token = models_tokens ["oneapi" ][llm_params ["model" ]]
205- except KeyError :
206- raise KeyError ("Model not supported" )
207- return OneApi (llm_params )
208-
209- elif "nvidia" in llm_params ["model" ]:
210- from langchain_nvidia_ai_endpoints import ChatNVIDIA
211-
212- try :
213- self .model_token = models_tokens ["nvidia" ][llm_params ["model" ].split ("/" )[- 1 ]]
214- llm_params ["model" ] = "/" .join (llm_params ["model" ].split ("/" )[1 :])
215- except KeyError :
216- raise KeyError ("Model not supported" )
217- return ChatNVIDIA (llm_params )
142+ self .model_token = models_tokens [llm_params ["model" ]][llm_params ["model" ]]
143+ except KeyError :
144+ print ("Model not found, using default token size (8192)" )
145+ self .model_token = 8192
218146
147+ try :
148+ if llm_params ["model_provider" ] not in {"oneapi" , "nvidia" , "ernie" , "deepseek" }:
149+ with warnings .catch_warnings ():
150+ warnings .simplefilter ("ignore" )
151+ return init_chat_model (** llm_params )
219152 else :
220- model_name = llm_params ["model" ].split ("/" )[- 1 ]
221- return handle_model (model_name , llm_params ["model" ], model_name )
153+ if "deepseek" in llm_params ["model" ]:
154+ return DeepSeek (** llm_params )
155+
156+ if "ernie" in llm_params ["model" ]:
157+ from langchain_community .chat_models import ErnieBotChat
158+ return ErnieBotChat (** llm_params )
159+
160+ if "oneapi" in llm_params ["model" ]:
161+ return OneApi (** llm_params )
162+
163+ if "nvidia" in llm_params ["model" ]:
164+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
165+ return ChatNVIDIA (** llm_params )
222166
223- except KeyError as e :
224- print (f"Model not supported : { e } " )
167+ except Exception as e :
168+ print (f"Error instancing model : { e } " )
225169
226170
227171 def get_state (self , key = None ) -> dict :
0 commit comments