import uuid
from pydantic import BaseModel

-from langchain_community.chat_models import ChatOllama
-from langchain_openai import ChatOpenAI
-
+from langchain_community.chat_models import ChatOllama, ErnieBotChat
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
-from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
-from langchain_community.chat_models import ErnieBotChat
+from langchain.chat_models import init_chat_model
+
from ..helpers import models_tokens
from ..models import (
    OneApi,
    DeepSeek
)
+from ..utils.logging import set_verbosity_warning, set_verbosity_info

-from langchain.chat_models import init_chat_model
-
-from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
-
-from ..helpers import models_tokens


class AbstractGraph(ABC):
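
Most provider branches in _create_llm below delegate to LangChain's init_chat_model, which selects the concrete chat class from a model_provider hint plus passthrough keyword arguments. A minimal sketch of such a call (model name, provider, and temperature are illustrative):

    llm = init_chat_model(model="gpt-4o-mini", model_provider="openai", temperature=0)
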
@@ -65,14 +58,14 @@ class AbstractGraph(ABC):
        >>> result = my_graph.run()
    """

-    def __init__(self, prompt: str, config: dict,
+    def __init__(self, prompt: str, config: dict,
                 source: Optional[str] = None, schema: Optional[BaseModel] = None):

        self.prompt = prompt
        self.source = source
        self.config = config
        self.schema = schema
-        self.llm_model = self._create_llm(config["llm"], chat=True)
+        self.llm_model = self._create_llm(config["llm"])
        self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
            config["embeddings"])
        self.verbose = False if config is None else config.get(
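
For orientation: config["llm"] is required, while config["embeddings"] is optional and falls back to a default embedder derived from the LLM settings. A minimal config sketch using the Ollama path shown later in this diff (model name and token count are illustrative):

    graph_config = {
        "llm": {
            "model": "ollama/llama3",   # the "ollama/" prefix is stripped inside _create_llm
            "model_tokens": 8192,       # illustrative context size
        },
        # "embeddings" omitted: _create_default_embedder derives one from the LLM config
        "verbose": True,
    }
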
@@ -128,7 +121,7 @@ def set_common_params(self, params: dict, overwrite=False):
        for node in self.graph.nodes:
            node.update_config(params, overwrite)

-    def _create_llm(self, llm_config: dict, chat=False) -> object:
+    def _create_llm(self, llm_config: dict) -> object:
        """
        Create a large language model instance based on the configuration provided.

@@ -148,9 +141,9 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        # If model instance is passed directly instead of the model details
        if "model_instance" in llm_params:
            try:
-                self.model_token = llm_params["model_tokens"]
+                self.model_token = llm_params["model_tokens"]
            except KeyError as exc:
-                raise KeyError("model_tokens not specified") from exc
+                raise KeyError("model_tokens not specified") from exc
            return llm_params["model_instance"]

        # Instantiate the language model based on the model name
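
As this branch shows, a ready-made model object can be passed straight through, in which case model_tokens must accompany it because no lookup in models_tokens is possible. A sketch (class choice and key are placeholders):

    from langchain_openai import ChatOpenAI

    llm_config = {
        "model_instance": ChatOpenAI(api_key="YOUR_KEY"),  # any LangChain chat model
        "model_tokens": 128000,  # required: the limit cannot be inferred from an instance
    }
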
@@ -161,23 +154,26 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return init_chat_model(**llm_params)
-        elif "oneapi" in llm_params["model"]:
+
+        if "oneapi" in llm_params["model"]:
            # take the model name after the last slash
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["oneapi"][llm_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return OneApi(llm_params)
-        elif "fireworks" in llm_params["model"]:
+
+        if "fireworks" in llm_params["model"]:
            try:
                self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            llm_params["model_provider"] = "fireworks"
            return init_chat_model(**llm_params)
-        elif "azure" in llm_params["model"]:
+
+        if "azure" in llm_params["model"]:
            # take the model name after the last slash
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
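
The fireworks branch (and the nvidia one in the next hunk) splits the provider-prefixed model string two ways: the last path segment keys the token-limit lookup, while everything after the first "/" becomes the model name sent to the provider. Traced on an illustrative name:

    model = "fireworks/accounts/fireworks/models/llama-v3-70b-instruct"
    model.split("/")[-1]              # "llama-v3-70b-instruct"  (models_tokens key)
    "/".join(model.split("/")[1:])    # "accounts/fireworks/models/llama-v3-70b-instruct"
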
@@ -186,38 +182,42 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                raise KeyError("Model not supported") from exc
            llm_params["model_provider"] = "azure_openai"
            return init_chat_model(**llm_params)
-        elif "nvidia" in llm_params["model"]:
+
+        if "nvidia" in llm_params["model"]:
            try:
                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return ChatNVIDIA(**llm_params)
-        elif "gemini" in llm_params["model"]:
+
+        if "gemini" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["gemini"][llm_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            llm_params["model_provider"] = "google_genai"
            return init_chat_model(**llm_params)
-        elif llm_params["model"].startswith("claude"):
+
+        if llm_params["model"].startswith("claude"):
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["claude"][llm_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            llm_params["model_provider"] = "anthropic"
            return init_chat_model(**llm_params)
-        elif llm_params["model"].startswith("vertexai"):
+
+        if llm_params["model"].startswith("vertexai"):
            try:
                self.model_token = models_tokens["vertexai"][llm_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            llm_params["model_provider"] = "google_vertexai"
            return init_chat_model(**llm_params)

-        elif "ollama" in llm_params["model"]:
+        if "ollama" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("ollama/")[-1]
            llm_params["model_provider"] = "ollama"

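
The gemini, claude, and vertexai branches above work the same way but hard-code the init_chat_model provider ids ("google_genai", "anthropic", "google_vertexai"). For example, an Anthropic config needs only the bare model name (name and key are illustrative):

    llm_config = {"model": "claude-3-haiku-20240307", "api_key": "YOUR_ANTHROPIC_KEY"}
    # _create_llm sets llm_config["model_provider"] = "anthropic"
    # and returns init_chat_model(**llm_config)
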
@@ -238,7 +238,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:

            return init_chat_model(**llm_params)

-        elif "hugging_face" in llm_params["model"]:
+        if "hugging_face" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["hugging_face"][llm_params["model"]]
@@ -247,7 +247,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                self.model_token = 8192
            llm_params["model_provider"] = "hugging_face"
            return init_chat_model(**llm_params)
-        elif "groq" in llm_params["model"]:
+
+        if "groq" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]

            try:
@@ -257,41 +258,43 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                self.model_token = 8192
            llm_params["model_provider"] = "groq"
            return init_chat_model(**llm_params)
-        elif "bedrock" in llm_params["model"]:
+
+        if "bedrock" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]
-            model_id = llm_params["model"]
-            client = llm_params.get("client", None)
            try:
                self.model_token = models_tokens["bedrock"][llm_params["model"]]
            except KeyError:
                print("model not found, using default token size (8192)")
                self.model_token = 8192
            llm_params["model_provider"] = "bedrock"
            return init_chat_model(**llm_params)
-        elif "claude-3-" in llm_params["model"]:
+
+        if "claude-3-" in llm_params["model"]:
            try:
                self.model_token = models_tokens["claude"]["claude3"]
            except KeyError:
                print("model not found, using default token size (8192)")
                self.model_token = 8192
            llm_params["model_provider"] = "anthropic"
            return init_chat_model(**llm_params)
-        elif "deepseek" in llm_params["model"]:
+
+        if "deepseek" in llm_params["model"]:
            try:
                self.model_token = models_tokens["deepseek"][llm_params["model"]]
            except KeyError:
                print("model not found, using default token size (8192)")
                self.model_token = 8192
            return DeepSeek(llm_params)
-        elif "ernie" in llm_params["model"]:
+
+        if "ernie" in llm_params["model"]:
            try:
                self.model_token = models_tokens["ernie"][llm_params["model"]]
            except KeyError:
                print("model not found, using default token size (8192)")
                self.model_token = 8192
            return ErnieBotChat(**llm_params)
-        else:
-            raise ValueError("Model provided by the configuration not supported")
+
+        raise ValueError("Model provided by the configuration not supported")

    def _create_default_embedder(self, llm_config=None) -> object:
        """
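
In the branches just above, DeepSeek and OneApi are project wrappers that receive the whole params dict, and these trailing branches fall back to a default 8192-token budget instead of raising when the name is missing from models_tokens. Illustrative config (model name and key are placeholders):

    llm_config = {"model": "deepseek-chat", "api_key": "YOUR_DEEPSEEK_KEY"}
    # -> returns DeepSeek(llm_config); self.model_token defaults to 8192 if unknown
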
@@ -308,7 +311,7 @@ def _create_default_embedder(self, llm_config=None) -> object:
                google_api_key=llm_config["api_key"], model="models/embedding-001"
            )
        if isinstance(self.llm_model, ChatOpenAI):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
                                    base_url=self.llm_model.openai_api_base)
        elif isinstance(self.llm_model, DeepSeek):
            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
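
The default embedder mirrors the chat model's provider and reuses its credentials; for a ChatOpenAI LLM the equivalent standalone construction would be roughly (key is a placeholder):

    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    llm = ChatOpenAI(api_key="YOUR_KEY")
    embedder = OpenAIEmbeddings(api_key=llm.openai_api_key,
                                base_url=llm.openai_api_base)
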
@@ -356,53 +359,53 @@ def _create_embedder(self, embedder_config: dict) -> object:
        # Instantiate the embedding model based on the model name
        if "openai" in embedder_params["model"]:
            return OpenAIEmbeddings(api_key=embedder_params["api_key"])
-        elif "azure" in embedder_params["model"]:
+        if "azure" in embedder_params["model"]:
            return AzureOpenAIEmbeddings()
        if "nvidia" in embedder_params["model"]:
            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
            try:
                models_tokens["nvidia"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
-            return NVIDIAEmbeddings(model=embedder_params["model"],
+            return NVIDIAEmbeddings(model=embedder_params["model"],
                                    nvidia_api_key=embedder_params["api_key"])
-        elif "ollama" in embedder_params["model"]:
+        if "ollama" in embedder_params["model"]:
            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
            try:
                models_tokens["ollama"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return OllamaEmbeddings(**embedder_params)
-        elif "hugging_face" in embedder_params["model"]:
+        if "hugging_face" in embedder_params["model"]:
            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
            try:
                models_tokens["hugging_face"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return HuggingFaceEmbeddings(model=embedder_params["model"])
-        elif "fireworks" in embedder_params["model"]:
+        if "fireworks" in embedder_params["model"]:
            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
            try:
                models_tokens["fireworks"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return FireworksEmbeddings(model=embedder_params["model"])
-        elif "gemini" in embedder_params["model"]:
+        if "gemini" in embedder_params["model"]:
            try:
                models_tokens["gemini"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return GoogleGenerativeAIEmbeddings(model=embedder_params["model"])
-        elif "bedrock" in embedder_params["model"]:
+        if "bedrock" in embedder_params["model"]:
            embedder_params["model"] = embedder_params["model"].split("/")[-1]
            client = embedder_params.get("client", None)
            try:
                models_tokens["bedrock"][embedder_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return BedrockEmbeddings(client=client, model_id=embedder_params["model"])
-        else:
-            raise ValueError("Model provided by the configuration not supported")
+
+        raise ValueError("Model provided by the configuration not supported")

    def get_state(self, key=None) -> dict:
        """
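
An explicit "embeddings" block overrides the default pairing and uses the same provider-prefix convention as the LLM config; the prefix picks the branch and is stripped before the models_tokens check. Sketch (model name illustrative and assumed present in models_tokens):

    graph_config = {
        "llm": {...},  # as before
        "embeddings": {"model": "ollama/nomic-embed-text"},  # -> OllamaEmbeddings(model="nomic-embed-text")
    }
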
@@ -444,11 +447,9 @@ def _create_graph(self):
        """
        Abstract method to create a graph representation.
        """
-        pass

    @abstractmethod
    def run(self) -> str:
        """
        Abstract method to execute the graph and return the result.
        """
-        pass
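
Putting it together, a concrete graph only has to supply the two abstract methods. The wiring below is a hypothetical sketch (the graph.execute call and state keys are assumptions, not an API guaranteed by this base class):

    class MyGraph(AbstractGraph):
        def _create_graph(self):
            ...  # build and return the node pipeline here

        def run(self) -> str:
            inputs = {"user_prompt": self.prompt, "url": self.source}
            self.final_state, self.execution_info = self.graph.execute(inputs)  # hypothetical
            return self.final_state.get("answer", "No answer found.")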