Skip to content

Commit d77a622

Browse files
committed
Merge branch '423-add-vertex-ai-integration' into support
2 parents 27c2dd2 + 119514b commit d77a622

File tree

5 files changed

+37
-5
lines changed

5 files changed

+37
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"langchain==0.1.15",
1717
"langchain-openai==0.1.6",
1818
"langchain-google-genai==1.0.3",
19+
"langchain-google-vertexai==1.0.6",
1920
"langchain-groq==0.1.3",
2021
"langchain-aws==0.1.3",
2122
"langchain-anthropic==0.1.11",

scrapegraphai/graphs/abstract_graph.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from langchain_aws import BedrockEmbeddings
1111
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
1212
from langchain_google_genai import GoogleGenerativeAIEmbeddings
13+
from langchain_google_vertexai import VertexAIEmbeddings
1314
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
1415
from langchain_fireworks import FireworksEmbeddings
1516
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
16-
1717
from ..helpers import models_tokens
1818
from ..models import (
1919
Anthropic,
@@ -25,7 +25,8 @@
2525
Ollama,
2626
OpenAI,
2727
OneApi,
28-
Fireworks
28+
Fireworks,
29+
VertexAI
2930
)
3031
from ..models.ernie import Ernie
3132
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -73,7 +74,7 @@ def __init__(self, prompt: str, config: dict,
7374
self.config = config
7475
self.schema = schema
7576
self.llm_model = self._create_llm(config["llm"], chat=True)
76-
self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder(
77+
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
7778
config["embeddings"])
7879
self.verbose = False if config is None else config.get(
7980
"verbose", False)
@@ -179,7 +180,6 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
179180
except KeyError as exc:
180181
raise KeyError("Model not supported") from exc
181182
return AzureOpenAI(llm_params)
182-
183183
elif "gemini" in llm_params["model"]:
184184
llm_params["model"] = llm_params["model"].split("/")[-1]
185185
try:
@@ -194,6 +194,12 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
194194
except KeyError as exc:
195195
raise KeyError("Model not supported") from exc
196196
return Anthropic(llm_params)
197+
elif llm_params["model"].startswith("vertexai"):
198+
try:
199+
self.model_token = models_tokens["vertexai"][llm_params["model"]]
200+
except KeyError as exc:
201+
raise KeyError("Model not supported") from exc
202+
return VertexAI(llm_params)
197203
elif "ollama" in llm_params["model"]:
198204
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
199205

@@ -287,9 +293,12 @@ def _create_default_embedder(self, llm_config=None) -> object:
287293
google_api_key=llm_config["api_key"], model="models/embedding-001"
288294
)
289295
if isinstance(self.llm_model, OpenAI):
290-
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
296+
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
297+
base_url=self.llm_model.openai_api_base)
291298
elif isinstance(self.llm_model, DeepSeek):
292299
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
300+
elif isinstance(self.llm_model, VertexAI):
301+
return VertexAIEmbeddings()
293302
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
294303
return self.llm_model
295304
elif isinstance(self.llm_model, AzureOpenAI):

scrapegraphai/helpers/models_tokens.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@
8080
"claude3": 200000,
8181
"claude3.5": 200000
8282
},
83+
"vertexai": {
84+
"gemini-1.5-flash": 128000,
85+
"gemini-1.5-pro": 128000,
86+
"gemini-1.0-pro": 128000
87+
},
8388
"bedrock": {
8489
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
8590
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from .deepseek import DeepSeek
1616
from .oneapi import OneApi
1717
from .fireworks import Fireworks
18+
from .vertex import VertexAI

scrapegraphai/models/vertex.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
VertexAI Module
3+
"""
4+
from langchain_google_vertexai import ChatVertexAI
5+
6+
class VertexAI(ChatVertexAI):
7+
"""
8+
A wrapper for the ChatVertexAI class that provides default configuration
9+
and could be extended with additional methods if needed.
10+
11+
Args:
12+
llm_config (dict): Configuration parameters for the language model.
13+
"""
14+
15+
def __init__(self, llm_config: dict):
16+
super().__init__(**llm_config)

0 commit comments

Comments
 (0)