@@ -72,14 +72,8 @@ class GraphState:
72
72
73
73
@dataclass
74
74
class GraphResult (MultiAgentResult ):
75
- """Result from graph execution - extends MultiAgentResult with graph-specific details.
75
+ """Result from graph execution - extends MultiAgentResult with graph-specific details."""
76
76
77
- The status field represents the outcome of the graph execution:
78
- - COMPLETED: The graph execution was successfully accomplished
79
- - FAILED: The graph execution failed or produced an error
80
- """
81
-
82
- status : Status = Status .PENDING
83
77
total_nodes : int = 0
84
78
completed_nodes : int = 0
85
79
failed_nodes : int = 0
@@ -146,6 +140,11 @@ def __init__(self) -> None:
146
140
147
141
def add_node (self , executor : Agent | MultiAgentBase , node_id : str | None = None ) -> GraphNode :
148
142
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
143
+ # Check for duplicate node instances
144
+ seen_instances = {id (node .executor ) for node in self .nodes .values ()}
145
+ if id (executor ) in seen_instances :
146
+ raise ValueError ("Duplicate node instance detected. Each node must have a unique object instance." )
147
+
149
148
# Auto-generate node_id if not provided
150
149
if node_id is None :
151
150
node_id = getattr (executor , "id" , None ) or getattr (executor , "name" , None ) or f"node_{ len (self .nodes )} "
@@ -248,24 +247,27 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
248
247
"""Initialize Graph."""
249
248
super ().__init__ ()
250
249
250
+ # Validate nodes for duplicate instances
251
+ self ._validate_graph (nodes )
252
+
251
253
self .nodes = nodes
252
254
self .edges = edges
253
255
self .entry_points = entry_points
254
256
self .state = GraphState ()
255
257
self .tracer = get_tracer ()
256
258
257
- def execute (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
258
- """Execute task synchronously."""
259
+ def __call__ (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
260
+ """Invoke the graph synchronously."""
259
261
260
262
def execute () -> GraphResult :
261
- return asyncio .run (self .execute_async (task ))
263
+ return asyncio .run (self .invoke_async (task ))
262
264
263
265
with ThreadPoolExecutor () as executor :
264
266
future = executor .submit (execute )
265
267
return future .result ()
266
268
267
- async def execute_async (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
268
- """Execute the graph asynchronously."""
269
+ async def invoke_async (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
270
+ """Invoke the graph asynchronously."""
269
271
logger .debug ("task=<%s> | starting graph execution" , task )
270
272
271
273
# Initialize state
@@ -293,6 +295,15 @@ async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) ->
293
295
self .state .execution_time = round ((time .time () - start_time ) * 1000 )
294
296
return self ._build_result ()
295
297
298
+ def _validate_graph (self , nodes : dict [str , GraphNode ]) -> None :
299
+ """Validate graph nodes for duplicate instances."""
300
+ # Check for duplicate node instances
301
+ seen_instances = set ()
302
+ for node in nodes .values ():
303
+ if id (node .executor ) in seen_instances :
304
+ raise ValueError ("Duplicate node instance detected. Each node must have a unique object instance." )
305
+ seen_instances .add (id (node .executor ))
306
+
296
307
async def _execute_graph (self ) -> None :
297
308
"""Unified execution flow with conditional routing."""
298
309
ready_nodes = list (self .entry_points )
@@ -355,7 +366,7 @@ async def _execute_node(self, node: GraphNode) -> None:
355
366
356
367
# Execute based on node type and create unified NodeResult
357
368
if isinstance (node .executor , MultiAgentBase ):
358
- multi_agent_result = await node .executor .execute_async (node_input )
369
+ multi_agent_result = await node .executor .invoke_async (node_input )
359
370
360
371
# Create NodeResult with MultiAgentResult directly
361
372
node_result = NodeResult (
@@ -444,7 +455,22 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
444
455
self .state .execution_count += node_result .execution_count
445
456
446
457
def _build_node_input (self , node : GraphNode ) -> list [ContentBlock ]:
447
- """Build input for a node based on dependency outputs."""
458
+ """Build input text for a node based on dependency outputs.
459
+
460
+ Example formatted output:
461
+ ```
462
+ Original Task: Analyze the quarterly sales data and create a summary report
463
+
464
+ Inputs from previous nodes:
465
+
466
+ From data_processor:
467
+ - Agent: Sales data processed successfully. Found 1,247 transactions totaling $89,432.
468
+ - Agent: Key trends: 15% increase in Q3, top product category is Electronics.
469
+
470
+ From validator:
471
+ - Agent: Data validation complete. All records verified, no anomalies detected.
472
+ ```
473
+ """
448
474
# Get satisfied dependencies
449
475
dependency_results = {}
450
476
for edge in self .edges :
@@ -491,12 +517,12 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
491
517
def _build_result (self ) -> GraphResult :
492
518
"""Build graph result from current state."""
493
519
return GraphResult (
520
+ status = self .state .status ,
494
521
results = self .state .results ,
495
522
accumulated_usage = self .state .accumulated_usage ,
496
523
accumulated_metrics = self .state .accumulated_metrics ,
497
524
execution_count = self .state .execution_count ,
498
525
execution_time = self .state .execution_time ,
499
- status = self .state .status ,
500
526
total_nodes = self .state .total_nodes ,
501
527
completed_nodes = len (self .state .completed_nodes ),
502
528
failed_nodes = len (self .state .failed_nodes ),
0 commit comments