11"""
22SmartScraperGraph Module Burr Version
33"""
4- from typing import Tuple
4+ from typing import Tuple , Union
55
66from burr import tracking
77from burr .core import Application , ApplicationBuilder , State , default , when
1414from langchain_community .document_transformers import Html2TextTransformer , EmbeddingsRedundantFilter
1515from langchain_community .vectorstores import FAISS
1616from langchain_core .documents import Document
17+ from langchain_core import load as lc_serde
1718from langchain_core .output_parsers import JsonOutputParser
1819from langchain_core .prompts import PromptTemplate
1920from langchain_core .runnables import RunnableParallel
@@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]:
 
 @action(reads=["user_prompt", "parsed_doc", "doc"],
         writes=["relevant_chunks"])
-def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]:
-    # bug around input serialization with tracker
-    llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
-    embedder_model = OpenAIEmbeddings()
+def rag_node(state: State, llm_model: str, embedder_model: object) -> tuple[dict, State]:
+    # bug around input serialization with tracker -- so instantiate objects here:
+    llm_model = OpenAI({"model_name": llm_model})
+    embedder_model = OpenAIEmbeddings() if embedder_model == "openai" else None
     user_prompt = state["user_prompt"]
     doc = state["parsed_doc"]
 
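The hunk above switches rag_node to a plain-string llm_model input because Burr's tracker currently can't serialize rich objects passed as action inputs. A minimal standalone sketch of that pattern (ask_node and make_client are hypothetical, not part of this module):

    from burr.core import State, action

    def make_client(model_name: str):
        # hypothetical factory standing in for OpenAI({"model_name": model_name})
        class _Stub:
            def ask(self, prompt: str) -> str:
                return f"[{model_name}] {prompt}"
        return _Stub()

    @action(reads=["user_prompt"], writes=["answer"])
    def ask_node(state: State, model_name: str) -> tuple[dict, State]:
        # build the non-serializable client inside the action; only the string
        # model_name ever crosses the tracker/serialization boundary
        client = make_client(model_name)
        result = {"answer": client.ask(state["user_prompt"])}
        return result, state.update(**result)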
@@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[d
 
 @action(reads=["user_prompt", "relevant_chunks", "parsed_doc", "doc"],
         writes=["answer"])
-def generate_answer_node(state: State, llm_model: object) -> tuple[dict, State]:
-    llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
+def generate_answer_node(state: State, llm_model: str) -> tuple[dict, State]:
+    # bug around input serialization with tracker -- so instantiate objects here:
+    llm_model = OpenAI({"model_name": llm_model})
+
     user_prompt = state["user_prompt"]
     doc = state.get("relevant_chunks",
                     state.get("parsed_doc",
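The state.get chain above (it continues past the end of the hunk) prefers the most processed representation available. A plain-dict illustration with hypothetical values:

    state = {"parsed_doc": "cleaned text"}
    doc = state.get("relevant_chunks", state.get("parsed_doc", state.get("doc")))
    assert doc == "cleaned text"  # no chunks yet, so it falls back to parsed_doc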
@@ -207,21 +210,49 @@ def post_run_step(
 ):
     print(f"Finishing action: {action.name}")
 
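Only the tail of PrintLnHook's post_run_step is visible in this hunk. A hedged reconstruction of the whole hook, assuming Burr's standard lifecycle base classes:

    from burr.core import Action, State
    from burr.lifecycle import PostRunStepHook, PreRunStepHook

    class PrintLnHook(PostRunStepHook, PreRunStepHook):
        # **future_kwargs absorbs any extra keyword arguments Burr passes to hooks
        def pre_run_step(self, *, state: State, action: Action, **future_kwargs):
            print(f"Starting action: {action.name}")

        def post_run_step(self, *, state: State, action: Action, **future_kwargs):
            print(f"Finishing action: {action.name}")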
+import json
+
+def _deserialize_document(x: Union[str, dict]) -> Document:
+    if isinstance(x, dict):
+        return lc_serde.load(x)
+    elif isinstance(x, str):
+        try:
+            return lc_serde.loads(x)
+        except json.JSONDecodeError:
+            return Document(page_content=x)
+    raise ValueError("Couldn't deserialize document")
+
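For reference, the lc_serde helpers used above are langchain_core.load's load/loads, which round-trip with dumpd/dumps. A quick sketch, assuming a serializable Document:

    from langchain_core.documents import Document
    from langchain_core.load import dumpd, dumps, load, loads

    doc = Document(page_content="Paris is the capital of France.")
    assert load(dumpd(doc)).page_content == doc.page_content   # dict round-trip
    assert loads(dumps(doc)).page_content == doc.page_content  # JSON-string round-trip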
 
 def run(prompt: str, input_key: str, source: str, config: dict) -> str:
+    # these config values aren't really used yet.
     llm_model = config["llm_model"]
-
     embedder_model = config["embedder_model"]
-    open_ai_embedder = OpenAIEmbeddings()
+    # open_ai_embedder = OpenAIEmbeddings()
     chunk_size = config["model_token"]
 
+    tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
+    app_instance_id = "testing-12345678919"
     initial_state = {
         "user_prompt": prompt,
         input_key: source,
     }
-    from burr.core import expr
-    tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
-
+    entry_point = "fetch_node"
+    if app_instance_id:
+        persisted_state = tracker.load(None, app_id=app_instance_id, sequence_no=None)
+        if not persisted_state:
+            print(f"Warning: No persisted state found for app_id {app_instance_id}.")
+        else:
+            initial_state = persisted_state["state"]
+            # for now we manually deserialize persisted LangChain documents back into
+            # LangChain objects, i.e. we know which state values need to be LC objects
+            initial_state = initial_state.update(**{
+                "doc": _deserialize_document(initial_state["doc"])
+            })
+            docs = [_deserialize_document(doc) for doc in initial_state["relevant_chunks"]]
+            initial_state = initial_state.update(**{
+                "relevant_chunks": docs
+            })
+            entry_point = persisted_state["position"]
 
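One caveat in the resume block above: a run persisted before rag_node completed would have no relevant_chunks key, so the list comprehension would likely raise a KeyError. A possible guard, sketched with the same State.get accessor the actions already use:

    chunks = initial_state.get("relevant_chunks")
    if chunks is not None:
        initial_state = initial_state.update(
            relevant_chunks=[_deserialize_document(d) for d in chunks]
        )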
     app = (
         ApplicationBuilder()
@@ -236,16 +267,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
             ("parse_node", "rag_node", default),
             ("rag_node", "generate_answer_node", default)
         )
-        # .with_entrypoint("fetch_node")
-        # .with_state(**initial_state)
-        .initialize_from(
-            tracker,
-            resume_at_next_action=True,  # always resume from entrypoint in the case of failure
-            default_state=initial_state,
-            default_entrypoint="fetch_node",
-        )
-        # .with_identifiers(app_id="testing-123456")
-        .with_tracker(project="smart-scraper-graph")
+        .with_entrypoint(entry_point)
+        .with_state(**initial_state)
+        # this will work once the serialization plugin for LangChain objects is done:
+        # .initialize_from(
+        #     tracker,
+        #     resume_at_next_action=True,  # always resume from entrypoint in the case of failure
+        #     default_state=initial_state,
+        #     default_entrypoint="fetch_node",
+        # )
+        .with_identifiers(app_id=app_instance_id)
+        .with_tracker(tracker)
         .with_hooks(PrintLnHook())
         .build()
     )
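The rest of run() is unchanged by this commit and elided here; presumably it executes the graph and returns the answer. A minimal sketch of that tail, assuming Burr's Application.run API and the action parameter names used above:

    last_action, result, final_state = app.run(
        halt_after=["generate_answer_node"],
        inputs={
            "llm_model": llm_model,            # e.g. "gpt-3.5-turbo"
            "embedder_model": embedder_model,  # e.g. "openai"
            "chunk_size": chunk_size,          # note: still the "bar" placeholder
        },
    )
    return final_state["answer"]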
@@ -270,8 +302,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
     source = "https://en.wikipedia.org/wiki/Paris"
     input_key = "url"
     config = {
-        "llm_model": "rag-token",
-        "embedder_model": "foo",
+        "llm_model": "gpt-3.5-turbo",
+        "embedder_model": "openai",
         "model_token": "bar",
     }
-    run(prompt, input_key, source, config)
+    print(run(prompt, input_key, source, config))