11"""
22AbstractGraph Module
33"""
4+
45from abc import ABC , abstractmethod
56from typing import Optional
67import uuid
910from langchain .chat_models import init_chat_model
1011from langchain_core .rate_limiters import InMemoryRateLimiter
1112from ..helpers import models_tokens
12- from ..models import (
13- OneApi ,
14- DeepSeek
15- )
13+ from ..models import OneApi , DeepSeek
1614from ..utils .logging import set_verbosity_warning , set_verbosity_info
1715
16+
1817class AbstractGraph (ABC ):
1918 """
2019 Scaffolding class for creating a graph representation and executing it.
@@ -39,14 +38,18 @@ class AbstractGraph(ABC):
        ...         # Implementation of graph creation here
        ...         return graph
        ...
-         >>> my_graph = MyGraph("Example Graph",
+         >>> my_graph = MyGraph("Example Graph",
        {"llm": {"model": "gpt-3.5-turbo"}}, "example_source")
        >>> result = my_graph.run()
    """

-     def __init__(self, prompt: str, config: dict,
-                  source: Optional[str] = None, schema: Optional[BaseModel] = None):
-
+     def __init__(
+         self,
+         prompt: str,
+         config: dict,
+         source: Optional[str] = None,
+         schema: Optional[BaseModel] = None,
+     ):
        if config.get("llm").get("temperature") is None:
            config["llm"]["temperature"] = 0

@@ -55,14 +58,13 @@ def __init__(self, prompt: str, config: dict,
        self.config = config
        self.schema = schema
        self.llm_model = self._create_llm(config["llm"])
-         self.verbose = False if config is None else config.get(
-             "verbose", False)
-         self.headless = True if self.config is None else config.get(
-             "headless", True)
+         self.verbose = False if config is None else config.get("verbose", False)
+         self.headless = True if self.config is None else config.get("headless", True)
        self.loader_kwargs = self.config.get("loader_kwargs", {})
        self.cache_path = self.config.get("cache_path", False)
        self.browser_base = self.config.get("browser_base")
        self.scrape_do = self.config.get("scrape_do")
+         self.storage_state = self.config.get("storage_state")

        self.graph = self._create_graph()
        self.final_state = None
@@ -81,7 +83,7 @@ def __init__(self, prompt: str, config: dict,
8183 "loader_kwargs" : self .loader_kwargs ,
8284 "llm_model" : self .llm_model ,
8385 "cache_path" : self .cache_path ,
84- }
86+ }
8587
8688 self .set_common_params (common_params , overwrite = True )
8789
@@ -129,7 +131,8 @@ def _create_llm(self, llm_config: dict) -> object:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    llm_params["rate_limiter"] = InMemoryRateLimiter(
-                         requests_per_second=requests_per_second)
+                         requests_per_second=requests_per_second
+                     )
            if max_retries is not None:
                llm_params["max_retries"] = max_retries

@@ -140,30 +143,55 @@ def _create_llm(self, llm_config: dict) -> object:
                raise KeyError("model_tokens not specified") from exc
            return llm_params["model_instance"]

-         known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
-             "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
-             "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
-
-         if '/' in llm_params["model"]:
-             split_model_provider = llm_params["model"].split("/", 1)
-             llm_params["model_provider"] = split_model_provider[0]
-             llm_params["model"] = split_model_provider[1]
+         known_providers = {
+             "openai",
+             "azure_openai",
+             "google_genai",
+             "google_vertexai",
+             "ollama",
+             "oneapi",
+             "nvidia",
+             "groq",
+             "anthropic",
+             "bedrock",
+             "mistralai",
+             "hugging_face",
+             "deepseek",
+             "ernie",
+             "fireworks",
+             "togetherai",
+         }
+
+         if "/" in llm_params["model"]:
+             split_model_provider = llm_params["model"].split("/", 1)
+             llm_params["model_provider"] = split_model_provider[0]
+             llm_params["model"] = split_model_provider[1]
        else:
-             possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
+             possible_providers = [
+                 provider
+                 for provider, models_d in models_tokens.items()
+                 if llm_params["model"] in models_d
+             ]
            if len(possible_providers) <= 0:
                raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                                 If possible, try to use a model instance instead.""")
            llm_params["model_provider"] = possible_providers[0]
-             print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
-                 "If it was not intended please specify the model provider in the graph configuration"))
+             print(
+                 (
+                     f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
+                     "If it was not intended please specify the model provider in the graph configuration"
+                 )
+             )

        if llm_params["model_provider"] not in known_providers:
            raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                             If possible, try to use a model instance instead.""")

        if "model_tokens" not in llm_params:
            try:
-                 self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
+                 self.model_token = models_tokens[llm_params["model_provider"]][
+                     llm_params["model"]
+                 ]
            except KeyError:
                print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
                      using default token size (8192)""")
@@ -172,10 +200,17 @@ def _create_llm(self, llm_config: dict) -> object:
            self.model_token = llm_params["model_tokens"]

        try:
-             if llm_params["model_provider"] not in \
-                 {"oneapi","nvidia","ernie","deepseek","togetherai"}:
+             if llm_params["model_provider"] not in {
+                 "oneapi",
+                 "nvidia",
+                 "ernie",
+                 "deepseek",
+                 "togetherai",
+             }:
                if llm_params["model_provider"] == "bedrock":
-                     llm_params["model_kwargs"] = { "temperature": llm_params.pop("temperature") }
+                     llm_params["model_kwargs"] = {
+                         "temperature": llm_params.pop("temperature")
+                     }
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    return init_chat_model(**llm_params)
@@ -187,6 +222,7 @@ def _create_llm(self, llm_config: dict) -> object:

            if model_provider == "ernie":
                from langchain_community.chat_models import ErnieBotChat
+
                return ErnieBotChat(**llm_params)

            elif model_provider == "oneapi":
@@ -211,7 +247,6 @@ def _create_llm(self, llm_config: dict) -> object:
        except Exception as e:
            raise Exception(f"Error instancing model: {e}")

-
    def get_state(self, key=None) -> dict:
        """
        Get the final state of the graph.
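For readers of the patch, below is a minimal, hypothetical configuration that exercises the options this change touches: the temperature default, verbose/headless, the new storage_state key, and the provider/model split in _create_llm. The model name, the rate_limit key, and the storage-state path are illustrative assumptions, not values taken from the patch.

# Hypothetical config; the values below are placeholders, not mandated by this change.
graph_config = {
    "llm": {
        "model": "openai/gpt-4o-mini",  # "provider/model" is split on the first "/"
        # "temperature" omitted on purpose: __init__ defaults it to 0
        "rate_limit": {"requests_per_second": 1},  # assumed key name feeding InMemoryRateLimiter
        "model_tokens": 128000,  # skips the models_tokens lookup
    },
    "verbose": True,
    "headless": True,
    "storage_state": "state.json",  # new key read by __init__ in this diff (placeholder path)
}

# The provider-splitting rule from _create_llm, reproduced standalone:
model = graph_config["llm"]["model"]
if "/" in model:
    provider, model = model.split("/", 1)
    print(provider, model)  # -> openai gpt-4o-mini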