 from langchain_aws import BedrockEmbeddings
 from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from langchain_google_vertexai import VertexAIEmbeddings
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
-
 from ..helpers import models_tokens
 from ..models import (
     Anthropic,
     HuggingFace,
     Ollama,
     OpenAI,
-    OneApi
+    OneApi,
+    VertexAI
 )
 from ..models.ernie import Ernie
 from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -71,7 +72,7 @@ def __init__(self, prompt: str, config: dict,
         self.config = config
         self.schema = schema
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])
         self.verbose = False if config is None else config.get(
             "verbose", False)
@@ -102,7 +103,7 @@ def __init__(self, prompt: str, config: dict,
102103 "embedder_model" : self .embedder_model ,
103104 "cache_path" : self .cache_path ,
104105 }
105-
106+
106107 self .set_common_params (common_params , overwrite = True )
107108
108109 # set burr config
@@ -125,7 +126,7 @@ def set_common_params(self, params: dict, overwrite=False):
 
         for node in self.graph.nodes:
             node.update_config(params, overwrite)
-
+
     def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
         Create a large language model instance based on the configuration provided.
@@ -170,7 +171,6 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return AzureOpenAI(llm_params)
-
         elif "gemini" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["gemini"][llm_params["model"]]
@@ -183,6 +183,12 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return Anthropic(llm_params)
+        elif llm_params["model"].startswith("vertexai"):
+            try:
+                self.model_token = models_tokens["vertexai"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return VertexAI(llm_params)
         elif "ollama" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("ollama/")[-1]
 
@@ -275,10 +281,12 @@ def _create_default_embedder(self, llm_config=None) -> object:
                 google_api_key=llm_config["api_key"], model="models/embedding-001"
             )
         if isinstance(self.llm_model, OpenAI):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
+                                    base_url=self.llm_model.openai_api_base)
         elif isinstance(self.llm_model, DeepSeek):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
-
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, VertexAI):
+            return VertexAIEmbeddings()
         elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
             return self.llm_model
         elif isinstance(self.llm_model, AzureOpenAI):
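For reference, a minimal sketch (not part of the diff) of a graph config that would exercise the new VertexAI branches above. The model string "vertexai/gemini-1.5-pro" is a hypothetical placeholder: whatever key is used must start with "vertexai" and must exist in models_tokens["vertexai"] for the token lookup to succeed.

# Hypothetical config sketch; the model name is assumed for illustration only.
graph_config = {
    "llm": {
        "model": "vertexai/gemini-1.5-pro",  # startswith("vertexai") -> VertexAI(llm_params)
    },
    # No "embeddings" key, so _create_default_embedder() is used; because
    # self.llm_model is a VertexAI instance, it returns VertexAIEmbeddings().
}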