Skip to content

Commit 78b245f

Browse files
Lancetnikclaudemarklysze
authored
feat(beta): add on_human_input middleware hook (#2477)
* feat(beta): add on_human_input middleware hook Add a new `on_human_input` lifecycle hook to `BaseMiddleware`, allowing middleware to intercept, modify, or short-circuit human-in-the-loop requests and responses. Wire the hook through `wrap_hitl` and update docs and exports accordingly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * docs(beta): update roadmap details and fix tools.mdx formatting Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor(beta): replace _CapturingStream stub with MemoryStream in client tool tests Use MemoryStream and MagicMock instead of a custom _CapturingStream helper, making tests exercise the real stream infrastructure. * refactor(beta): extract shared ClientTool fixture in client tool tests Deduplicate schema/ClientTool construction into a pytest fixture and unify tool names across all three tests. --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
1 parent 95f6399 commit 78b245f

File tree

9 files changed

+338
-76
lines changed

9 files changed

+338
-76
lines changed

autogen/beta/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ async def _call_client(context: Context) -> None:
328328
)
329329

330330
stack.enter_context(
331-
context.stream.where(HumanInputRequest).sub_scope(self.__hitl_hook),
331+
context.stream.where(HumanInputRequest).sub_scope(self.__hitl_hook(middleware_instances)),
332332
)
333333

334334
self.__tool_executor.register(

autogen/beta/hitl.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,28 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from collections.abc import Awaitable, Callable
5+
from collections.abc import Awaitable, Callable, Iterable
66
from contextlib import AsyncExitStack
7+
from functools import partial
78
from typing import TypeAlias
89

910
from .annotations import Context
1011
from .events import HumanInputRequest, HumanMessage
1112
from .exceptions import HumanInputNotProvidedError
13+
from .middleware.base import BaseMiddleware, HumanInputHook
1214
from .utils import CONTEXT_OPTION_NAME, build_model
1315

1416
HumanHook: TypeAlias = Callable[..., HumanMessage] | Callable[..., Awaitable[HumanMessage]]
1517

18+
HitlExecution: TypeAlias = Callable[[HumanInputRequest, Context], Awaitable[None]]
1619

17-
def wrap_hitl(func: HumanHook) -> None:
20+
21+
def wrap_hitl(
22+
func: HumanHook,
23+
) -> Callable[[Iterable["BaseMiddleware"]], HitlExecution]:
1824
call_model = build_model(func)
1925

20-
async def wrapper(event: HumanInputRequest, context: Context) -> None:
26+
async def _call_model(event: HumanInputRequest, context: Context) -> HumanMessage:
2127
async with AsyncExitStack() as stack:
2228
event = await call_model.asolve(
2329
event,
@@ -26,10 +32,24 @@ async def wrapper(event: HumanInputRequest, context: Context) -> None:
2632
dependency_provider=context.dependency_provider,
2733
**{CONTEXT_OPTION_NAME: context},
2834
)
29-
await context.send(event)
35+
return event
36+
37+
def make_hook(middlewares: Iterable["BaseMiddleware"]) -> HitlExecution:
38+
ask_user: HumanInputHook = _call_model
39+
for middleware in middlewares:
40+
ask_user = partial(middleware.on_human_input, ask_user)
41+
42+
async def wrapper(event: HumanInputRequest, context: Context) -> None:
43+
event = await ask_user(event, context)
44+
await context.send(event)
45+
46+
return wrapper
47+
48+
return make_hook
3049

31-
return wrapper
3250

51+
def default_hitl_hook(middlewares: Iterable["BaseMiddleware"]) -> HitlExecution:
52+
async def _call_model(event: HumanInputRequest, context: Context) -> None:
53+
raise HumanInputNotProvidedError
3354

34-
def default_hitl_hook() -> HumanMessage:
35-
raise HumanInputNotProvidedError
55+
return _call_model

autogen/beta/middleware/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,27 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from .base import AgentTurn, BaseMiddleware, LLMCall, Middleware, ToolExecution, ToolResultType
6-
from .builtin import HistoryLimiter, LoggingMiddleware, RetryMiddleware, TokenLimiter
5+
from .base import (
6+
AgentTurn,
7+
BaseMiddleware,
8+
HumanInputHook,
9+
LLMCall,
10+
Middleware,
11+
ToolExecution,
12+
ToolResultType,
13+
)
14+
from .builtin import (
15+
HistoryLimiter,
16+
LoggingMiddleware,
17+
RetryMiddleware,
18+
TokenLimiter,
19+
)
720

821
__all__ = (
922
"AgentTurn",
1023
"BaseMiddleware",
1124
"HistoryLimiter",
25+
"HumanInputHook",
1226
"LLMCall",
1327
"LoggingMiddleware",
1428
"Middleware",

autogen/beta/middleware/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@
66
from typing import Any, Protocol, TypeAlias
77

88
from autogen.beta.annotations import Context
9-
from autogen.beta.events import BaseEvent, ClientToolCall, ModelResponse, ToolCall, ToolError, ToolResult
9+
from autogen.beta.events import (
10+
BaseEvent,
11+
ClientToolCall,
12+
HumanInputRequest,
13+
HumanMessage,
14+
ModelResponse,
15+
ToolCall,
16+
ToolError,
17+
ToolResult,
18+
)
1019

1120

1221
class MiddlewareFactory(Protocol):
@@ -36,6 +45,7 @@ def __call__(
3645
AgentTurn: TypeAlias = Callable[["BaseEvent", "Context"], Awaitable["ModelResponse"]]
3746
ToolExecution: TypeAlias = Callable[["ToolCall", "Context"], Awaitable[ToolResultType]]
3847
LLMCall: TypeAlias = Callable[["Sequence[BaseEvent]", "Context"], Awaitable["ModelResponse"]]
48+
HumanInputHook: TypeAlias = Callable[["HumanInputRequest", "Context"], Awaitable["HumanMessage"]]
3949

4050

4151
class BaseMiddleware:
@@ -70,3 +80,11 @@ async def on_llm_call(
7080
context: "Context",
7181
) -> "ModelResponse":
7282
return await call_next(events, context)
83+
84+
async def on_human_input(
85+
self,
86+
call_next: HumanInputHook,
87+
event: "HumanInputRequest",
88+
context: "Context",
89+
) -> "HumanMessage":
90+
return await call_next(event, context)
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright (c) 2023 - 2026, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from unittest.mock import MagicMock
6+
7+
import pytest
8+
9+
from autogen.beta import Agent, Context
10+
from autogen.beta.events import BaseEvent, HumanInputRequest, HumanMessage, ToolCall
11+
from autogen.beta.middleware import BaseMiddleware, Middleware
12+
from autogen.beta.middleware.base import HumanInputHook
13+
from autogen.beta.testing import TestConfig
14+
15+
16+
@pytest.fixture()
17+
def test_config() -> TestConfig:
18+
return TestConfig(
19+
ToolCall(name="my_tool"),
20+
"result",
21+
)
22+
23+
24+
class MockHumanInputMiddleware(BaseMiddleware):
25+
def __init__(
26+
self,
27+
event: BaseEvent,
28+
ctx: Context,
29+
mock: MagicMock,
30+
) -> None:
31+
super().__init__(event, ctx)
32+
self.mock = mock
33+
34+
async def on_human_input(
35+
self,
36+
call_next: HumanInputHook,
37+
event: HumanInputRequest,
38+
ctx: Context,
39+
) -> HumanMessage:
40+
self.mock.enter(event.content)
41+
result = await call_next(event, ctx)
42+
self.mock.exit(result.content)
43+
return result
44+
45+
46+
@pytest.mark.asyncio()
47+
async def test_human_input_middleware(mock: MagicMock, test_config: TestConfig) -> None:
48+
async def my_tool(ctx: Context) -> str:
49+
await ctx.input("Say smth", timeout=1.0)
50+
return ""
51+
52+
def hitl_hook(event: HumanInputRequest) -> HumanMessage:
53+
return HumanMessage(content="answer")
54+
55+
agent = Agent(
56+
"",
57+
config=test_config,
58+
tools=[my_tool],
59+
hitl_hook=hitl_hook,
60+
middleware=[Middleware(MockHumanInputMiddleware, mock=mock)],
61+
)
62+
63+
await agent.ask("Hi!")
64+
65+
mock.enter.assert_called_once_with("Say smth")
66+
mock.exit.assert_called_once_with("answer")
67+
68+
69+
class OrderingHumanInputMiddleware(BaseMiddleware):
70+
def __init__(
71+
self,
72+
event: BaseEvent,
73+
ctx: Context,
74+
mock: MagicMock,
75+
position: int,
76+
) -> None:
77+
super().__init__(event, ctx)
78+
self.mock = mock
79+
self.position = position
80+
81+
async def on_human_input(
82+
self,
83+
call_next: HumanInputHook,
84+
event: HumanInputRequest,
85+
ctx: Context,
86+
) -> HumanMessage:
87+
self.mock.enter(self.position)
88+
result = await call_next(event, ctx)
89+
self.mock.exit(self.position)
90+
return result
91+
92+
93+
@pytest.mark.asyncio()
94+
async def test_human_input_middleware_call_sequence(mock: MagicMock, test_config: TestConfig) -> None:
95+
async def my_tool(ctx: Context) -> str:
96+
await ctx.input("Say smth", timeout=1.0)
97+
return ""
98+
99+
def hitl_hook(event: HumanInputRequest) -> HumanMessage:
100+
return HumanMessage(content="answer")
101+
102+
agent = Agent(
103+
"",
104+
config=test_config,
105+
tools=[my_tool],
106+
hitl_hook=hitl_hook,
107+
middleware=[Middleware(OrderingHumanInputMiddleware, mock=mock, position=i) for i in range(1, 4)],
108+
)
109+
110+
await agent.ask("Hi!")
111+
112+
assert [c.args[0] for c in mock.enter.call_args_list] == [1, 2, 3]
113+
assert [c.args[0] for c in mock.exit.call_args_list] == [3, 2, 1]
114+
115+
116+
@pytest.mark.asyncio()
117+
async def test_human_input_middleware_mutates_request(mock: MagicMock, test_config: TestConfig) -> None:
118+
class MutatingMiddleware(BaseMiddleware):
119+
async def on_human_input(
120+
self,
121+
call_next: HumanInputHook,
122+
event: HumanInputRequest,
123+
ctx: Context,
124+
) -> HumanMessage:
125+
event = HumanInputRequest(content=event.content + "!")
126+
return await call_next(event, ctx)
127+
128+
async def my_tool(ctx: Context) -> str:
129+
await ctx.input("Say smth", timeout=1.0)
130+
return ""
131+
132+
def hitl_hook(event: HumanInputRequest) -> HumanMessage:
133+
mock.hitl(event.content)
134+
return HumanMessage(content="answer")
135+
136+
agent = Agent(
137+
"",
138+
config=test_config,
139+
tools=[my_tool],
140+
hitl_hook=hitl_hook,
141+
middleware=[MutatingMiddleware, MutatingMiddleware, MutatingMiddleware],
142+
)
143+
144+
await agent.ask("Hi!")
145+
146+
mock.hitl.assert_called_once_with("Say smth!!!")
147+
148+
149+
@pytest.mark.asyncio()
150+
async def test_human_input_middleware_mutates_response(mock: MagicMock, test_config: TestConfig) -> None:
151+
class MutatingMiddleware(BaseMiddleware):
152+
async def on_human_input(
153+
self,
154+
call_next: HumanInputHook,
155+
event: HumanInputRequest,
156+
ctx: Context,
157+
) -> HumanMessage:
158+
result = await call_next(event, ctx)
159+
return HumanMessage(content=result.content + "!")
160+
161+
async def my_tool(ctx: Context) -> str:
162+
mock(await ctx.input("Say smth", timeout=1.0))
163+
return ""
164+
165+
def hitl_hook(event: HumanInputRequest) -> HumanMessage:
166+
return HumanMessage(content="answer")
167+
168+
agent = Agent(
169+
"",
170+
config=test_config,
171+
tools=[my_tool],
172+
hitl_hook=hitl_hook,
173+
middleware=[MutatingMiddleware, MutatingMiddleware, MutatingMiddleware],
174+
)
175+
176+
await agent.ask("Hi!")
177+
178+
mock.assert_called_once_with("answer!!!")
179+
180+
181+
@pytest.mark.asyncio()
182+
async def test_human_input_middleware_short_circuits(mock: MagicMock, test_config: TestConfig) -> None:
183+
class ShortCircuitMiddleware(BaseMiddleware):
184+
async def on_human_input(
185+
self,
186+
call_next: HumanInputHook,
187+
event: HumanInputRequest,
188+
ctx: Context,
189+
) -> HumanMessage:
190+
mock.intercepted(event.content)
191+
return HumanMessage(content="intercepted")
192+
193+
async def my_tool(ctx: Context) -> str:
194+
mock(await ctx.input("Say smth", timeout=1.0))
195+
return ""
196+
197+
def hitl_hook(event: HumanInputRequest) -> HumanMessage:
198+
mock.hitl()
199+
return HumanMessage(content="answer")
200+
201+
agent = Agent(
202+
"",
203+
config=test_config,
204+
tools=[my_tool],
205+
hitl_hook=hitl_hook,
206+
middleware=[Middleware(ShortCircuitMiddleware)],
207+
)
208+
209+
await agent.ask("Hi!")
210+
211+
mock.intercepted.assert_called_once_with("Say smth")
212+
mock.hitl.assert_not_called()
213+
mock.assert_called_once_with("intercepted")

0 commit comments

Comments
 (0)