Skip to content

Commit 1f64b4b

Browse files
authored
feat(multiagent): introduce Swarm multi-agent orchestrator (strands-agents#416)
1 parent 089ccb3 commit 1f64b4b

File tree

9 files changed

+1305
-35
lines changed

9 files changed

+1305
-35
lines changed

src/strands/multiagent/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010

1111
from .base import MultiAgentBase, MultiAgentResult
1212
from .graph import GraphBuilder, GraphResult
13+
from .swarm import Swarm, SwarmResult
1314

1415
__all__ = [
1516
"GraphBuilder",
1617
"GraphResult",
1718
"MultiAgentBase",
1819
"MultiAgentResult",
20+
"Swarm",
21+
"SwarmResult",
1922
]

src/strands/multiagent/base.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,15 @@ def get_agent_results(self) -> list[AgentResult]:
5959

6060
@dataclass
6161
class MultiAgentResult:
62-
"""Result from multi-agent execution with accumulated metrics."""
62+
"""Result from multi-agent execution with accumulated metrics.
6363
64-
results: dict[str, NodeResult]
64+
The status field represents the outcome of the MultiAgentBase execution:
65+
- COMPLETED: The execution was successfully accomplished
66+
- FAILED: The execution failed or produced an error
67+
"""
68+
69+
status: Status = Status.PENDING
70+
results: dict[str, NodeResult] = field(default_factory=lambda: {})
6571
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
6672
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
6773
execution_count: int = 0
@@ -76,11 +82,11 @@ class MultiAgentBase(ABC):
7682
"""
7783

7884
@abstractmethod
79-
async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
80-
"""Execute task asynchronously."""
81-
raise NotImplementedError("execute_async not implemented")
85+
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
86+
"""Invoke asynchronously."""
87+
raise NotImplementedError("invoke_async not implemented")
8288

8389
@abstractmethod
84-
def execute(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
85-
"""Execute task synchronously."""
86-
raise NotImplementedError("execute not implemented")
90+
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
91+
"""Invoke synchronously."""
92+
raise NotImplementedError("__call__ not implemented")

src/strands/multiagent/graph.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,8 @@ class GraphState:
7272

7373
@dataclass
7474
class GraphResult(MultiAgentResult):
75-
"""Result from graph execution - extends MultiAgentResult with graph-specific details.
75+
"""Result from graph execution - extends MultiAgentResult with graph-specific details."""
7676

77-
The status field represents the outcome of the graph execution:
78-
- COMPLETED: The graph execution was successfully accomplished
79-
- FAILED: The graph execution failed or produced an error
80-
"""
81-
82-
status: Status = Status.PENDING
8377
total_nodes: int = 0
8478
completed_nodes: int = 0
8579
failed_nodes: int = 0
@@ -146,6 +140,11 @@ def __init__(self) -> None:
146140

147141
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
148142
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
143+
# Check for duplicate node instances
144+
seen_instances = {id(node.executor) for node in self.nodes.values()}
145+
if id(executor) in seen_instances:
146+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
147+
149148
# Auto-generate node_id if not provided
150149
if node_id is None:
151150
node_id = getattr(executor, "id", None) or getattr(executor, "name", None) or f"node_{len(self.nodes)}"
@@ -248,24 +247,27 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
248247
"""Initialize Graph."""
249248
super().__init__()
250249

250+
# Validate nodes for duplicate instances
251+
self._validate_graph(nodes)
252+
251253
self.nodes = nodes
252254
self.edges = edges
253255
self.entry_points = entry_points
254256
self.state = GraphState()
255257
self.tracer = get_tracer()
256258

257-
def execute(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
258-
"""Execute task synchronously."""
259+
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
260+
"""Invoke the graph synchronously."""
259261

260262
def execute() -> GraphResult:
261-
return asyncio.run(self.execute_async(task))
263+
return asyncio.run(self.invoke_async(task))
262264

263265
with ThreadPoolExecutor() as executor:
264266
future = executor.submit(execute)
265267
return future.result()
266268

267-
async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
268-
"""Execute the graph asynchronously."""
269+
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
270+
"""Invoke the graph asynchronously."""
269271
logger.debug("task=<%s> | starting graph execution", task)
270272

271273
# Initialize state
@@ -293,6 +295,15 @@ async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) ->
293295
self.state.execution_time = round((time.time() - start_time) * 1000)
294296
return self._build_result()
295297

298+
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
299+
"""Validate graph nodes for duplicate instances."""
300+
# Check for duplicate node instances
301+
seen_instances = set()
302+
for node in nodes.values():
303+
if id(node.executor) in seen_instances:
304+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
305+
seen_instances.add(id(node.executor))
306+
296307
async def _execute_graph(self) -> None:
297308
"""Unified execution flow with conditional routing."""
298309
ready_nodes = list(self.entry_points)
@@ -355,7 +366,7 @@ async def _execute_node(self, node: GraphNode) -> None:
355366

356367
# Execute based on node type and create unified NodeResult
357368
if isinstance(node.executor, MultiAgentBase):
358-
multi_agent_result = await node.executor.execute_async(node_input)
369+
multi_agent_result = await node.executor.invoke_async(node_input)
359370

360371
# Create NodeResult with MultiAgentResult directly
361372
node_result = NodeResult(
@@ -444,7 +455,22 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
444455
self.state.execution_count += node_result.execution_count
445456

446457
def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
447-
"""Build input for a node based on dependency outputs."""
458+
"""Build input text for a node based on dependency outputs.
459+
460+
Example formatted output:
461+
```
462+
Original Task: Analyze the quarterly sales data and create a summary report
463+
464+
Inputs from previous nodes:
465+
466+
From data_processor:
467+
- Agent: Sales data processed successfully. Found 1,247 transactions totaling $89,432.
468+
- Agent: Key trends: 15% increase in Q3, top product category is Electronics.
469+
470+
From validator:
471+
- Agent: Data validation complete. All records verified, no anomalies detected.
472+
```
473+
"""
448474
# Get satisfied dependencies
449475
dependency_results = {}
450476
for edge in self.edges:
@@ -491,12 +517,12 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
491517
def _build_result(self) -> GraphResult:
492518
"""Build graph result from current state."""
493519
return GraphResult(
520+
status=self.state.status,
494521
results=self.state.results,
495522
accumulated_usage=self.state.accumulated_usage,
496523
accumulated_metrics=self.state.accumulated_metrics,
497524
execution_count=self.state.execution_count,
498525
execution_time=self.state.execution_time,
499-
status=self.state.status,
500526
total_nodes=self.state.total_nodes,
501527
completed_nodes=len(self.state.completed_nodes),
502528
failed_nodes=len(self.state.failed_nodes),

0 commit comments

Comments
 (0)