@@ -16,13 +16,58 @@ class PrintLnHook(PostRunStepHook, PreRunStepHook):
1616 """
1717 Hook to print the action name before and after it is executed.
1818 """
19-
19+
2020 def pre_run_step (self , * , state : "State" , action : "Action" , ** future_kwargs : Any ):
2121 print (f"Starting action: { action .name } " )
2222
2323 def post_run_step (self , * , state : "State" , action : "Action" , ** future_kwargs : Any ):
2424 print (f"Finishing action: { action .name } " )
2525
26+
27+ class BurrNodeBridge (Action ):
28+ """Bridge class to convert a base graph node to a Burr action.
29+ This is nice because we can dynamically declare the inputs/outputs (and not rely on function-parsing).
30+ """
31+
32+ def __init__ (self , node ):
33+ """Instantiates a BurrNodeBridge object.
34+ """
35+ super (BurrNodeBridge , self ).__init__ ()
36+ self .node = node
37+
38+ @property
39+ def reads (self ) -> list [str ]:
40+ return parse_boolean_expression (self .node .input )
41+
42+ def run (self , state : State , ** run_kwargs ) -> dict :
43+ node_inputs = {key : state [key ] for key in self .reads }
44+ result_state = self .node .execute (node_inputs , ** run_kwargs )
45+ return result_state
46+
47+ @property
48+ def writes (self ) -> list [str ]:
49+ return self .node .output
50+
51+ def update (self , result : dict , state : State ) -> State :
52+ return state .update (** state )
53+
54+
55+ def parse_boolean_expression (expression : str ) -> List [str ]:
56+ """
57+ Parse a boolean expression to extract the keys used in the expression, without boolean operators.
58+
59+ Args:
60+ expression (str): The boolean expression to parse.
61+
62+ Returns:
63+ list: A list of unique keys used in the expression.
64+ """
65+
66+ # Use regular expression to extract all unique keys
67+ keys = re .findall (r'\w+' , expression )
68+ return list (set (keys )) # Remove duplicates
69+
70+
2671class BurrBridge :
2772 """
2873 Bridge class to integrate Burr into ScrapeGraphAI graphs.
@@ -106,12 +151,16 @@ def _create_action(self, node) -> Any:
106151 function: The Burr action function.
107152 """
108153
109- @action (reads = self ._parse_boolean_expression (node .input ), writes = node .output )
110- def dynamic_action (state : State , ** kwargs ):
111- node_inputs = {key : state [key ] for key in self ._parse_boolean_expression (node .input )}
112- result_state = node .execute (node_inputs , ** kwargs )
113- return result_state , state .update (** result_state )
114- return dynamic_action
154+ # @action(reads=parse_boolean_expression(node.input), writes=node.output)
155+ # def dynamic_action(state: State, **kwargs):
156+ # node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
157+ # result_state = node.execute(node_inputs, **kwargs)
158+ # return result_state, state.update(**result_state)
159+ #
160+ # return dynamic_action
161+ # import pdb
162+ # pdb.set_trace()
163+ return BurrNodeBridge (node )
115164
116165 def _create_transitions (self ) -> List [Tuple [str , str , Any ]]:
117166 """
@@ -125,22 +174,7 @@ def _create_transitions(self) -> List[Tuple[str, str, Any]]:
125174 for from_node , to_node in self .base_graph .edges .items ():
126175 transitions .append ((from_node , to_node , default ))
127176 return transitions
128-
129- def _parse_boolean_expression (self , expression : str ) -> List [str ]:
130- """
131- Parse a boolean expression to extract the keys used in the expression, without boolean operators.
132177
133- Args:
134- expression (str): The boolean expression to parse.
135-
136- Returns:
137- list: A list of unique keys used in the expression.
138- """
139-
140- # Use regular expression to extract all unique keys
141- keys = re .findall (r'\w+' , expression )
142- return list (set (keys )) # Remove duplicates
143-
144178 def _convert_state_to_burr (self , state : Dict [str , Any ]) -> State :
145179 """
146180 Convert a dictionary state to a Burr state.
@@ -172,7 +206,7 @@ def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
172206 for key in burr_state .__dict__ .keys ():
173207 state [key ] = getattr (burr_state , key )
174208 return state
175-
209+
176210 def execute (self , initial_state : Dict [str , Any ] = {}) -> Dict [str , Any ]:
177211 """
178212 Execute the Burr application with the given initial state.
@@ -185,7 +219,7 @@ def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
185219 """
186220
187221 self .burr_app = self ._initialize_burr_app (initial_state )
188-
222+
189223 # TODO: to fix final nodes detection
190224 final_nodes = [self .burr_app .graph .actions [- 1 ].name ]
191225
@@ -195,4 +229,4 @@ def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
195229 inputs = self .burr_inputs
196230 )
197231
198- return self ._convert_state_from_burr (final_state )
232+ return self ._convert_state_from_burr (final_state )
0 commit comments