11"""
22AbstractGraph Module
33"""
4+
45from abc import ABC , abstractmethod
56from typing import Optional
7+
68from langchain_aws import BedrockEmbeddings
7- from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
89from langchain_community .embeddings import HuggingFaceHubEmbeddings , OllamaEmbeddings
910from langchain_google_genai import GoogleGenerativeAIEmbeddings
10- from ..helpers import models_tokens
11- from ..utils .logging import set_verbosity
12- from ..models import AzureOpenAI , Bedrock , Gemini , Groq , HuggingFace , Ollama , OpenAI , Anthropic
1311from langchain_google_genai .embeddings import GoogleGenerativeAIEmbeddings
12+ from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
13+
14+ from ..helpers import models_tokens
15+ from ..models import (
16+ Anthropic ,
17+ AzureOpenAI ,
18+ Bedrock ,
19+ Gemini ,
20+ Groq ,
21+ HuggingFace ,
22+ Ollama ,
23+ OpenAI ,
24+ )
25+ from ..utils .logging import set_verbosity_debug , set_verbosity_warning
26+
1427
1528class AbstractGraph (ABC ):
1629 """
@@ -46,29 +59,35 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4659 self .source = source
4760 self .config = config
4861 self .llm_model = self ._create_llm (config ["llm" ], chat = True )
49- self .embedder_model = self ._create_default_embedder (llm_config = config ["llm" ]
50- ) if "embeddings" not in config else self ._create_embedder (
51- config ["embeddings" ])
62+ self .embedder_model = (
63+ self ._create_default_embedder (llm_config = config ["llm" ])
64+ if "embeddings" not in config
65+ else self ._create_embedder (config ["embeddings" ])
66+ )
5267
5368 # Create the graph
5469 self .graph = self ._create_graph ()
5570 self .final_state = None
5671 self .execution_info = None
5772
5873 # Set common configuration parameters
59-
60- verbose = False if config is None else config .get (
61- "verbose" , False )
62- set_verbosity (config .get ("verbose" , "info" ))
63- self .headless = True if config is None else config .get (
64- "headless" , True )
74+
75+ verbose = bool (config and config .get ("verbose" ))
76+
77+ if verbose :
78+ set_verbosity_debug ()
79+ else :
80+ set_verbosity_warning ()
81+
82+ self .headless = True if config is None else config .get ("headless" , True )
6583 self .loader_kwargs = config .get ("loader_kwargs" , {})
6684
67- common_params = {"headless" : self .headless ,
68-
69- "loader_kwargs" : self .loader_kwargs ,
70- "llm_model" : self .llm_model ,
71- "embedder_model" : self .embedder_model }
85+ common_params = {
86+ "headless" : self .headless ,
87+ "loader_kwargs" : self .loader_kwargs ,
88+ "llm_model" : self .llm_model ,
89+ "embedder_model" : self .embedder_model ,
90+ }
7291 self .set_common_params (common_params , overwrite = False )
7392
7493 def set_common_params (self , params : dict , overwrite = False ):
@@ -81,25 +100,25 @@ def set_common_params(self, params: dict, overwrite=False):
81100
82101 for node in self .graph .nodes :
83102 node .update_config (params , overwrite )
84-
103+
85104 def _set_model_token (self , llm ):
86105
87- if ' Azure' in str (type (llm )):
106+ if " Azure" in str (type (llm )):
88107 try :
89108 self .model_token = models_tokens ["azure" ][llm .model_name ]
90109 except KeyError :
91110 raise KeyError ("Model not supported" )
92111
93- elif ' HuggingFaceEndpoint' in str (type (llm )):
94- if ' mistral' in llm .repo_id :
112+ elif " HuggingFaceEndpoint" in str (type (llm )):
113+ if " mistral" in llm .repo_id :
95114 try :
96- self .model_token = models_tokens [' mistral' ][llm .repo_id ]
115+ self .model_token = models_tokens [" mistral" ][llm .repo_id ]
97116 except KeyError :
98117 raise KeyError ("Model not supported" )
99- elif ' Google' in str (type (llm )):
118+ elif " Google" in str (type (llm )):
100119 try :
101- if ' gemini' in llm .model :
102- self .model_token = models_tokens [' gemini' ][llm .model ]
120+ if " gemini" in llm .model :
121+ self .model_token = models_tokens [" gemini" ][llm .model ]
103122 except KeyError :
104123 raise KeyError ("Model not supported" )
105124
@@ -117,17 +136,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
117136 KeyError: If the model is not supported.
118137 """
119138
120- llm_defaults = {
121- "temperature" : 0 ,
122- "streaming" : False
123- }
139+ llm_defaults = {"temperature" : 0 , "streaming" : False }
124140 llm_params = {** llm_defaults , ** llm_config }
125141
126142 # If model instance is passed directly instead of the model details
127- if ' model_instance' in llm_params :
143+ if " model_instance" in llm_params :
128144 if chat :
129- self ._set_model_token (llm_params [' model_instance' ])
130- return llm_params [' model_instance' ]
145+ self ._set_model_token (llm_params [" model_instance" ])
146+ return llm_params [" model_instance" ]
131147
132148 # Instantiate the language model based on the model name
133149 if "gpt-" in llm_params ["model" ]:
@@ -193,18 +209,20 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
193209 elif "bedrock" in llm_params ["model" ]:
194210 llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
195211 model_id = llm_params ["model" ]
196- client = llm_params .get (' client' , None )
212+ client = llm_params .get (" client" , None )
197213 try :
198214 self .model_token = models_tokens ["bedrock" ][llm_params ["model" ]]
199215 except KeyError as exc :
200216 raise KeyError ("Model not supported" ) from exc
201- return Bedrock ({
202- "client" : client ,
203- "model_id" : model_id ,
204- "model_kwargs" : {
205- "temperature" : llm_params ["temperature" ],
217+ return Bedrock (
218+ {
219+ "client" : client ,
220+ "model_id" : model_id ,
221+ "model_kwargs" : {
222+ "temperature" : llm_params ["temperature" ],
223+ },
206224 }
207- } )
225+ )
208226 elif "claude-3-" in llm_params ["model" ]:
209227 self .model_token = models_tokens ["claude" ]["claude3" ]
210228 return Anthropic (llm_params )
@@ -215,8 +233,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
215233 raise KeyError ("Model not supported" ) from exc
216234 return DeepSeek (llm_params )
217235 else :
218- raise ValueError (
219- "Model provided by the configuration not supported" )
236+ raise ValueError ("Model provided by the configuration not supported" )
220237
221238 def _create_default_embedder (self , llm_config = None ) -> object :
222239 """
@@ -229,8 +246,9 @@ def _create_default_embedder(self, llm_config=None) -> object:
229246 ValueError: If the model is not supported.
230247 """
231248 if isinstance (self .llm_model , Gemini ):
232- return GoogleGenerativeAIEmbeddings (google_api_key = llm_config ['api_key' ],
233- model = "models/embedding-001" )
249+ return GoogleGenerativeAIEmbeddings (
250+ google_api_key = llm_config ["api_key" ], model = "models/embedding-001"
251+ )
234252 if isinstance (self .llm_model , OpenAI ):
235253 return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
236254 elif isinstance (self .llm_model , AzureOpenAIEmbeddings ):
@@ -265,8 +283,8 @@ def _create_embedder(self, embedder_config: dict) -> object:
265283 Raises:
266284 KeyError: If the model is not supported.
267285 """
268- if ' model_instance' in embedder_config :
269- return embedder_config [' model_instance' ]
286+ if " model_instance" in embedder_config :
287+ return embedder_config [" model_instance" ]
270288 # Instantiate the embedding model based on the model name
271289 if "openai" in embedder_config ["model" ]:
272290 return OpenAIEmbeddings (api_key = embedder_config ["api_key" ])
@@ -283,28 +301,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
283301 try :
284302 models_tokens ["hugging_face" ][embedder_config ["model" ]]
285303 except KeyError as exc :
286- raise KeyError ("Model not supported" )from exc
304+ raise KeyError ("Model not supported" ) from exc
287305 return HuggingFaceHubEmbeddings (model = embedder_config ["model" ])
288306 elif "gemini" in embedder_config ["model" ]:
289307 try :
290308 models_tokens ["gemini" ][embedder_config ["model" ]]
291309 except KeyError as exc :
292- raise KeyError ("Model not supported" )from exc
310+ raise KeyError ("Model not supported" ) from exc
293311 return GoogleGenerativeAIEmbeddings (model = embedder_config ["model" ])
294312 elif "bedrock" in embedder_config ["model" ]:
295313 embedder_config ["model" ] = embedder_config ["model" ].split ("/" )[- 1 ]
296- client = embedder_config .get (' client' , None )
314+ client = embedder_config .get (" client" , None )
297315 try :
298316 models_tokens ["bedrock" ][embedder_config ["model" ]]
299317 except KeyError as exc :
300318 raise KeyError ("Model not supported" ) from exc
301319 return BedrockEmbeddings (client = client , model_id = embedder_config ["model" ])
302320 else :
303- raise ValueError (
304- "Model provided by the configuration not supported" )
321+ raise ValueError ("Model provided by the configuration not supported" )
305322
306323 def get_state (self , key = None ) -> dict :
307- """""
324+ """ ""
308325 Get the final state of the graph.
309326
310327 Args:
0 commit comments