|
19 | 19 | from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings |
20 | 20 | from langchain_fireworks import FireworksEmbeddings, ChatFireworks |
21 | 21 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI |
22 | | -from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings |
| 22 | +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA |
| 23 | +from langchain_community.chat_models import ErnieBotChat |
23 | 24 | from ..helpers import models_tokens |
24 | 25 | from ..models import ( |
25 | 26 | OneApi, |
26 | | - Nvidia, |
27 | 27 | DeepSeek |
28 | 28 | ) |
29 | | -from ..models.ernie import Ernie |
| 29 | + |
30 | 30 | from langchain.chat_models import init_chat_model |
31 | 31 |
|
32 | 32 | from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info |
@@ -192,7 +192,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object: |
192 | 192 | llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) |
193 | 193 | except KeyError as exc: |
194 | 194 | raise KeyError("Model not supported") from exc |
195 | | - return Nvidia(llm_params) |
| 195 | + return ChatNVIDIA(llm_params) |
196 | 196 | elif "gemini" in llm_params["model"]: |
197 | 197 | llm_params["model"] = llm_params["model"].split("/")[-1] |
198 | 198 | try: |
@@ -289,7 +289,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object: |
289 | 289 | except KeyError: |
290 | 290 | print("model not found, using default token size (8192)") |
291 | 291 | self.model_token = 8192 |
292 | | - return Ernie(llm_params) |
| 292 | + return ErnieBotChat(llm_params) |
293 | 293 | else: |
294 | 294 | raise ValueError("Model provided by the configuration not supported") |
295 | 295 |
|
@@ -320,7 +320,7 @@ def _create_default_embedder(self, llm_config=None) -> object: |
320 | 320 | return AzureOpenAIEmbeddings() |
321 | 321 | elif isinstance(self.llm_model, ChatFireworks): |
322 | 322 | return FireworksEmbeddings(model=self.llm_model.model_name) |
323 | | - elif isinstance(self.llm_model, Nvidia): |
| 323 | + elif isinstance(self.llm_model, ChatNVIDIA): |
324 | 324 | return NVIDIAEmbeddings(model=self.llm_model.model_name) |
325 | 325 | elif isinstance(self.llm_model, ChatOllama): |
326 | 326 | # unwrap the kwargs from the model which is a dict |
|
0 commit comments