Skip to content

Commit cfaf7ee

Browse files
authored
Merge pull request #284 from DAGWorks-Inc/burr_integration
Updates Burr bridge to use class-based API
2 parents 6cbd84f + d96840f commit cfaf7ee

File tree

2 files changed

+62
-26
lines changed

2 files changed

+62
-26
lines changed

examples/openai/burr_integration_openai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44

55
import os
6+
import uuid
7+
68
from dotenv import load_dotenv
79

810
from langchain_openai import OpenAIEmbeddings
@@ -88,7 +90,7 @@
8890
entry_point=fetch_node,
8991
use_burr=True,
9092
burr_config={
91-
"app_instance_id": "custom_graph_openai",
93+
"app_instance_id": str(uuid.uuid4()),
9294
"inputs": {
9395
"llm_model": graph_config["llm"].get("model", "gpt-3.5-turbo"),
9496
}

scrapegraphai/integrations/burr_bridge.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,58 @@ class PrintLnHook(PostRunStepHook, PreRunStepHook):
1616
"""
1717
Hook to print the action name before and after it is executed.
1818
"""
19-
19+
2020
def pre_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any):
2121
print(f"Starting action: {action.name}")
2222

2323
def post_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any):
2424
print(f"Finishing action: {action.name}")
2525

26+
27+
class BurrNodeBridge(Action):
    """Bridge class that exposes a base graph node as a class-based Burr action.

    Using a class (rather than the ``@action`` decorator) lets the inputs and
    outputs be declared dynamically from the wrapped node, without relying on
    function-signature parsing.
    """

    def __init__(self, node):
        """Instantiate a BurrNodeBridge object.

        Args:
            node: The graph node to wrap. It must expose ``input`` (a boolean
                expression string naming the state keys it needs), ``output``
                (the state keys it writes) and ``execute(inputs, **kwargs)``.
        """
        super().__init__()
        self.node = node

    @property
    def reads(self) -> list[str]:
        """State keys the wrapped node reads, parsed from ``node.input``."""
        return parse_boolean_expression(self.node.input)

    def run(self, state: State, **run_kwargs) -> dict:
        """Execute the wrapped node on the subset of state it declares as input.

        Args:
            state: The current Burr state.
            **run_kwargs: Extra inputs forwarded unchanged to ``node.execute``.

        Returns:
            dict: Whatever ``node.execute`` returns for the selected inputs.
        """
        node_inputs = {key: state[key] for key in self.reads}
        return self.node.execute(node_inputs, **run_kwargs)

    @property
    def writes(self) -> list[str]:
        """State keys the wrapped node writes, taken from ``node.output``."""
        return self.node.output

    def update(self, result: dict, state: State) -> State:
        """Merge the node's result into the Burr state.

        Args:
            result: The dictionary returned by :meth:`run`.
            state: The state before this action ran.

        Returns:
            State: A state updated with the entries of ``result``.
        """
        # The result (not the incoming state) must be applied; updating the
        # state with itself would silently drop everything the node produced.
        return state.update(**result)
53+
54+
55+
def parse_boolean_expression(expression: str) -> List[str]:
    """
    Parse a boolean expression to extract the keys used in the expression, without boolean operators.

    Keys are the maximal runs of word characters (letters, digits, underscore);
    everything else (``&``, ``|``, parentheses, whitespace, ...) is treated as
    a separator.

    Args:
        expression (str): The boolean expression to parse.

    Returns:
        list: The unique keys used in the expression, in order of first
        appearance. An empty expression yields an empty list.
    """

    # Use regular expression to extract all keys
    keys = re.findall(r'\w+', expression)
    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # result is deterministic across runs (a set would not be, because string
    # hashing is randomised per process).
    return list(dict.fromkeys(keys))
69+
70+
2671
class BurrBridge:
2772
"""
2873
Bridge class to integrate Burr into ScrapeGraphAI graphs.
@@ -106,12 +151,16 @@ def _create_action(self, node) -> Any:
106151
function: The Burr action function.
107152
"""
108153

109-
@action(reads=self._parse_boolean_expression(node.input), writes=node.output)
110-
def dynamic_action(state: State, **kwargs):
111-
node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
112-
result_state = node.execute(node_inputs, **kwargs)
113-
return result_state, state.update(**result_state)
114-
return dynamic_action
154+
# @action(reads=parse_boolean_expression(node.input), writes=node.output)
155+
# def dynamic_action(state: State, **kwargs):
156+
# node_inputs = {key: state[key] for key in self._parse_boolean_expression(node.input)}
157+
# result_state = node.execute(node_inputs, **kwargs)
158+
# return result_state, state.update(**result_state)
159+
#
160+
# return dynamic_action
161+
# import pdb
162+
# pdb.set_trace()
163+
return BurrNodeBridge(node)
115164

116165
def _create_transitions(self) -> List[Tuple[str, str, Any]]:
117166
"""
@@ -125,22 +174,7 @@ def _create_transitions(self) -> List[Tuple[str, str, Any]]:
125174
for from_node, to_node in self.base_graph.edges.items():
126175
transitions.append((from_node, to_node, default))
127176
return transitions
128-
129-
def _parse_boolean_expression(self, expression: str) -> List[str]:
130-
"""
131-
Parse a boolean expression to extract the keys used in the expression, without boolean operators.
132177

133-
Args:
134-
expression (str): The boolean expression to parse.
135-
136-
Returns:
137-
list: A list of unique keys used in the expression.
138-
"""
139-
140-
# Use regular expression to extract all unique keys
141-
keys = re.findall(r'\w+', expression)
142-
return list(set(keys)) # Remove duplicates
143-
144178
def _convert_state_to_burr(self, state: Dict[str, Any]) -> State:
145179
"""
146180
Convert a dictionary state to a Burr state.
@@ -172,7 +206,7 @@ def _convert_state_from_burr(self, burr_state: State) -> Dict[str, Any]:
172206
for key in burr_state.__dict__.keys():
173207
state[key] = getattr(burr_state, key)
174208
return state
175-
209+
176210
def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
177211
"""
178212
Execute the Burr application with the given initial state.
@@ -185,7 +219,7 @@ def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
185219
"""
186220

187221
self.burr_app = self._initialize_burr_app(initial_state)
188-
222+
189223
# TODO: to fix final nodes detection
190224
final_nodes = [self.burr_app.graph.actions[-1].name]
191225

@@ -195,4 +229,4 @@ def execute(self, initial_state: Dict[str, Any] = {}) -> Dict[str, Any]:
195229
inputs=self.burr_inputs
196230
)
197231

198-
return self._convert_state_from_burr(final_state)
232+
return self._convert_state_from_burr(final_state)

0 commit comments

Comments
 (0)