Skip to content

Commit fc0dadb

Browse files
committed
feat: add nvidia connection
1 parent f078fe8 commit fc0dadb

File tree

7 files changed

+75
-1
lines changed

7 files changed

+75
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"langchain-groq==0.1.3",
2121
"langchain-aws==0.1.3",
2222
"langchain-anthropic==0.1.11",
23+
"langchain-nvidia-ai-endpoints==0.1.6",
2324
"html2text==2024.2.26",
2425
"faiss-cpu==1.8.0",
2526
"beautifulsoup4==4.12.3",

requirements-dev.lock

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ aiohttp==3.9.5
1414
# via langchain
1515
# via langchain-community
1616
# via langchain-fireworks
17+
# via langchain-nvidia-ai-endpoints
1718
aiosignal==1.3.1
1819
# via aiohttp
1920
alabaster==0.7.16
@@ -268,6 +269,7 @@ langchain-core==0.1.52
268269
# via langchain-google-genai
269270
# via langchain-google-vertexai
270271
# via langchain-groq
272+
# via langchain-nvidia-ai-endpoints
271273
# via langchain-openai
272274
# via langchain-text-splitters
273275
langchain-fireworks==0.1.3
@@ -278,6 +280,8 @@ langchain-google-vertexai==1.0.4
278280
# via scrapegraphai
279281
langchain-groq==0.1.3
280282
# via scrapegraphai
283+
langchain-nvidia-ai-endpoints==0.1.6
284+
# via scrapegraphai
281285
langchain-openai==0.1.6
282286
# via scrapegraphai
283287
langchain-text-splitters==0.0.2
@@ -348,6 +352,7 @@ pandas==2.2.2
348352
# via streamlit
349353
pillow==10.3.0
350354
# via fireworks-ai
355+
# via langchain-nvidia-ai-endpoints
351356
# via matplotlib
352357
# via streamlit
353358
platformdirs==4.2.2

requirements.lock

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ aiohttp==3.9.5
1212
# via langchain
1313
# via langchain-community
1414
# via langchain-fireworks
15+
# via langchain-nvidia-ai-endpoints
1516
aiosignal==1.3.1
1617
# via aiohttp
1718
annotated-types==0.7.0
@@ -187,6 +188,7 @@ langchain-core==0.1.52
187188
# via langchain-google-genai
188189
# via langchain-google-vertexai
189190
# via langchain-groq
191+
# via langchain-nvidia-ai-endpoints
190192
# via langchain-openai
191193
# via langchain-text-splitters
192194
langchain-fireworks==0.1.3
@@ -197,6 +199,8 @@ langchain-google-vertexai==1.0.4
197199
# via scrapegraphai
198200
langchain-groq==0.1.3
199201
# via scrapegraphai
202+
langchain-nvidia-ai-endpoints==0.1.6
203+
# via scrapegraphai
200204
langchain-openai==0.1.6
201205
# via scrapegraphai
202206
langchain-text-splitters==0.0.2
@@ -238,6 +242,7 @@ pandas==2.2.2
238242
# via scrapegraphai
239243
pillow==10.3.0
240244
# via fireworks-ai
245+
# via langchain-nvidia-ai-endpoints
241246
playwright==1.43.0
242247
# via scrapegraphai
243248
# via undetected-playwright

scrapegraphai/graphs/abstract_graph.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
1515
from langchain_fireworks import FireworksEmbeddings
1616
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
17+
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
1718
from ..helpers import models_tokens
1819
from ..models import (
1920
Anthropic,
@@ -26,7 +27,8 @@
2627
OpenAI,
2728
OneApi,
2829
Fireworks,
29-
VertexAI
30+
VertexAI,
31+
Nvidia
3032
)
3133
from ..models.ernie import Ernie
3234
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -180,6 +182,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
180182
except KeyError as exc:
181183
raise KeyError("Model not supported") from exc
182184
return AzureOpenAI(llm_params)
185+
elif "nvidia" in llm_params["model"]:
186+
try:
187+
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
188+
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
189+
except KeyError as exc:
190+
raise KeyError("Model not supported") from exc
191+
return Nvidia(llm_params)
183192
elif "gemini" in llm_params["model"]:
184193
llm_params["model"] = llm_params["model"].split("/")[-1]
185194
try:
@@ -305,6 +314,8 @@ def _create_default_embedder(self, llm_config=None) -> object:
305314
return AzureOpenAIEmbeddings()
306315
elif isinstance(self.llm_model, Fireworks):
307316
return FireworksEmbeddings(model=self.llm_model.model_name)
317+
elif isinstance(self.llm_model, Nvidia):
318+
return NVIDIAEmbeddings(model=self.llm_model.model_name)
308319
elif isinstance(self.llm_model, Ollama):
309320
# unwrap the kwargs from the model which is a dict
310321
params = self.llm_model._lc_kwargs
@@ -341,6 +352,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
341352
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
342353
elif "azure" in embedder_params["model"]:
343354
return AzureOpenAIEmbeddings()
355+
if "nvidia" in embedder_params["model"]:
356+
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
357+
try:
358+
models_tokens["nvidia"][embedder_params["model"]]
359+
except KeyError as exc:
360+
raise KeyError("Model not supported") from exc
361+
return NVIDIAEmbeddings(model=embedder_params["model"],
362+
nvidia_api_key=embedder_params["api_key"])
344363
elif "ollama" in embedder_params["model"]:
345364
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
346365
try:

scrapegraphai/helpers/models_tokens.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,24 @@
7979
"oneapi": {
8080
"qwen-turbo": 6000
8181
},
82+
"nvidia": {
83+
"meta/llama3-70b-instruct": 419,
84+
"meta/llama3-8b-instruct": 419,
85+
"nemotron-4-340b-instruct": 1024,
86+
"databricks/dbrx-instruct": 4096,
87+
"google/codegemma-7b": 8192,
88+
"google/gemma-2b": 2048,
89+
"google/gemma-7b": 8192,
90+
"google/recurrentgemma-2b": 2048,
91+
"meta/codellama-70b": 16384,
92+
"meta/llama2-70b": 4096,
93+
"microsoft/phi-3-mini-128k-instruct": 122880,
94+
"mistralai/mistral-7b-instruct-v0.2": 4096,
95+
"mistralai/mistral-large": 8192,
96+
"mistralai/mixtral-8x22b-instruct-v0.1": 32768,
97+
"mistralai/mixtral-8x7b-instruct-v0.1": 8192,
98+
"snowflake/arctic": 16384,
99+
},
82100
"groq": {
83101
"llama3-8b-8192": 8192,
84102
"llama3-70b-8192": 8192,

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .oneapi import OneApi
1717
from .fireworks import Fireworks
1818
from .vertex import VertexAI
19+
from .nvidia import Nvidia

scrapegraphai/models/nvidia.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
"""
Thin wrapper around ChatNVIDIA from langchain_nvidia_ai_endpoints.

Nvidia exists so callers can build the chat model from a single
configuration dictionary instead of spelling out every keyword argument
of the underlying ChatNVIDIA constructor. It adds no behavior of its
own; it simply forwards the configuration to the superclass, and it may
be extended later with convenience methods if needed.
"""

from langchain_nvidia_ai_endpoints import ChatNVIDIA


class Nvidia(ChatNVIDIA):
    """ChatNVIDIA subclass constructed from a configuration dict.

    Args:
        llm_config (dict): Keyword arguments forwarded verbatim to the
            ChatNVIDIA constructor (model name, API key, etc.).
    """

    def __init__(self, llm_config: dict):
        # Unpack the dict so ChatNVIDIA receives ordinary keyword args.
        super().__init__(**llm_config)

0 commit comments

Comments
 (0)