Skip to content

Commit b378ca7

Browse files
authored
Python: extend HITL support for all orchestration patterns (#2620)
* Support HITL for orchestration patterns * Cleanup around naming * Fix typing issues * Clean up * Naming clean up * Updates to HITL to make it cleaner * Rename human input hook to orchestration request info * Clean up per PR feedback
1 parent 0d9ae19 commit b378ca7

23 files changed

+2186
-36
lines changed

python/packages/core/agent_framework/_workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
MagenticStallInterventionRequest,
8686
StandardMagenticManager,
8787
)
88+
from ._orchestration_request_info import AgentInputRequest, AgentResponseReviewRequest, RequestInfoInterceptor
8889
from ._orchestration_state import OrchestrationState
8990
from ._request_info_mixin import response_handler
9091
from ._runner import Runner
@@ -122,6 +123,8 @@
122123
"AgentExecutor",
123124
"AgentExecutorRequest",
124125
"AgentExecutorResponse",
126+
"AgentInputRequest",
127+
"AgentResponseReviewRequest",
125128
"AgentRunEvent",
126129
"AgentRunUpdateEvent",
127130
"Case",
@@ -164,6 +167,7 @@
164167
"Message",
165168
"OrchestrationState",
166169
"RequestInfoEvent",
170+
"RequestInfoInterceptor",
167171
"Runner",
168172
"RunnerContext",
169173
"SequentialBuilder",

python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def __init__(self, executor_id: str) -> None:
4747
self._max_rounds: int | None = None
4848
self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
4949

50-
def register_participant_entry(self, name: str, *, entry_id: str, is_agent: bool) -> None:
50+
def register_participant_entry(
51+
self, name: str, *, entry_id: str, is_agent: bool, exit_id: str | None = None
52+
) -> None:
5153
"""Record routing details for a participant's entry executor.
5254
5355
This method provides a unified interface for registering participants
@@ -57,8 +59,10 @@ def register_participant_entry(self, name: str, *, entry_id: str, is_agent: bool
5759
name: Participant name (used for selection and tracking)
5860
entry_id: Executor ID for this participant's entry point
5961
is_agent: Whether this is an AgentExecutor (True) or custom Executor (False)
62+
exit_id: Executor ID for this participant's exit point (where responses come from).
63+
If None, defaults to entry_id.
6064
"""
61-
self._registry.register(name, entry_id=entry_id, is_agent=is_agent)
65+
self._registry.register(name, entry_id=entry_id, is_agent=is_agent, exit_id=exit_id)
6266

6367
# Conversation state management (shared across all patterns)
6468

python/packages/core/agent_framework/_workflows/_concurrent.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ._checkpoint import CheckpointStorage
1515
from ._executor import Executor, handler
1616
from ._message_utils import normalize_messages_input
17+
from ._orchestration_request_info import RequestInfoInterceptor
1718
from ._workflow import Workflow
1819
from ._workflow_builder import WorkflowBuilder
1920
from ._workflow_context import WorkflowContext
@@ -209,15 +210,18 @@ def summarize(results):
209210
210211
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_custom_aggregator(summarize).build()
211212
212-
213213
# Enable checkpoint persistence so runs can resume
214214
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_checkpointing(storage).build()
215+
216+
# Enable request info before aggregation
217+
workflow = ConcurrentBuilder().participants([agent1, agent2]).with_request_info().build()
215218
"""
216219

217220
def __init__(self) -> None:
218221
self._participants: list[AgentProtocol | Executor] = []
219222
self._aggregator: Executor | None = None
220223
self._checkpoint_storage: CheckpointStorage | None = None
224+
self._request_info_enabled: bool = False
221225

222226
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder":
223227
r"""Define the parallel participants for this concurrent workflow.
@@ -296,12 +300,33 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Concurre
296300
self._checkpoint_storage = checkpoint_storage
297301
return self
298302

303+
def with_request_info(self) -> "ConcurrentBuilder":
304+
"""Enable request info before aggregation in the workflow.
305+
306+
When enabled, the workflow pauses after all parallel agents complete,
307+
emitting a RequestInfoEvent that allows the caller to review and optionally
308+
modify the combined results before aggregation. The caller provides feedback
309+
via the standard response_handler/request_info pattern.
310+
311+
Note:
312+
Unlike SequentialBuilder and GroupChatBuilder, ConcurrentBuilder does not
313+
support per-agent filtering since all agents run in parallel and results
314+
are collected together. The pause occurs once with all agent outputs received.
315+
316+
Returns:
317+
self: The builder instance for fluent chaining.
318+
"""
319+
self._request_info_enabled = True
320+
return self
321+
299322
def build(self) -> Workflow:
300323
r"""Build and validate the concurrent workflow.
301324
302325
Wiring pattern:
303326
- Dispatcher (internal) fans out the input to all `participants`
304-
- Fan-in aggregator collects `AgentExecutorResponse` objects
327+
- Fan-in collects `AgentExecutorResponse` objects from all participants
328+
- If request info is enabled, the orchestration emits a request info event with outputs from all participants
329+
before sending the outputs to the aggregator
305330
- Aggregator yields output and the workflow becomes idle. The output is either:
306331
- list[ChatMessage] (default aggregator: one user + one assistant per agent)
307332
- custom payload from the provided callback/executor
@@ -327,7 +352,16 @@ def build(self) -> Workflow:
327352
builder = WorkflowBuilder()
328353
builder.set_start_executor(dispatcher)
329354
builder.add_fan_out_edges(dispatcher, list(self._participants))
330-
builder.add_fan_in_edges(list(self._participants), aggregator)
355+
356+
if self._request_info_enabled:
357+
# Insert interceptor between fan-in and aggregator
358+
# participants -> fan-in -> interceptor -> aggregator
359+
request_info_interceptor = RequestInfoInterceptor(executor_id="request_info")
360+
builder.add_fan_in_edges(list(self._participants), request_info_interceptor)
361+
builder.add_edge(request_info_interceptor, aggregator)
362+
else:
363+
# Direct fan-in to aggregator
364+
builder.add_fan_in_edges(list(self._participants), aggregator)
331365

332366
if self._checkpoint_storage is not None:
333367
builder = builder.with_checkpointing(self._checkpoint_storage)

0 commit comments

Comments
 (0)