Skip to content

Commit 6d10a8f

Browse files
committed
Refactor ProxyAgent to extend BaseAgent and update orchestration
ProxyAgent now extends agent_framework's BaseAgent and implements the AgentProtocol with run and run_stream methods, providing improved compliance and streaming support. The orchestration_manager now extracts the inner ChatAgent from wrapper templates, ensuring correct agent registration for both wrapped and direct BaseAgent instances.
1 parent 77bddf0 commit 6d10a8f

File tree

2 files changed

+173
-113
lines changed

2 files changed

+173
-113
lines changed
Lines changed: 160 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
2-
ProxyAgentAF: Human clarification proxy implemented on agent_framework primitives.
2+
ProxyAgent: Human clarification proxy compliant with agent_framework.
33
44
Responsibilities:
55
- Request clarification from a human via websocket
6-
- Await response (with timeout + cancellation handling via orchestration_config)
7-
- Yield ChatResponseUpdate objects compatible with agent_framework streaming loops
6+
- Await response (with timeout + cancellation handling)
7+
- Yield AgentRunResponseUpdate objects compatible with agent_framework
88
"""
99

1010
from __future__ import annotations
@@ -13,16 +13,20 @@
1313
import logging
1414
import time
1515
import uuid
16-
from dataclasses import dataclass, field
17-
from typing import AsyncIterator, List, Optional
16+
from typing import Any, AsyncIterable
1817

1918
from agent_framework import (
20-
ChatResponseUpdate,
19+
AgentRunResponse,
20+
AgentRunResponseUpdate,
21+
BaseAgent,
22+
ChatMessage,
2123
Role,
2224
TextContent,
2325
UsageContent,
2426
UsageDetails,
27+
AgentThread,
2528
)
29+
2630
from af.config.settings import connection_config, orchestration_config
2731
from af.models.messages import (
2832
UserClarificationRequest,
@@ -34,73 +38,110 @@
3438
logger = logging.getLogger(__name__)
3539

3640

37-
# ---------------------------------------------------------------------------
38-
# Internal conversation structure (minimal alternative to SK AgentThread)
39-
# ---------------------------------------------------------------------------
40-
41-
@dataclass
42-
class ProxyConversation:
43-
conversation_id: str = field(default_factory=lambda: f"proxy_{uuid.uuid4().hex}")
44-
messages: List[str] = field(default_factory=list)
45-
46-
def add(self, content: str) -> None:
47-
self.messages.append(content)
48-
49-
50-
# ---------------------------------------------------------------------------
51-
# Proxy Agent AF
52-
# ---------------------------------------------------------------------------
53-
54-
class ProxyAgent:
41+
class ProxyAgent(BaseAgent):
5542
"""
56-
A lightweight "agent" that mediates human clarification.
57-
Not a model-backed agent; it orchestrates a request and emits a synthetic reply.
43+
A human-in-the-loop clarification agent extending agent_framework's BaseAgent.
44+
45+
This agent mediates human clarification requests rather than using an LLM.
46+
It follows the agent_framework protocol with run() and run_stream() methods.
5847
"""
5948

6049
def __init__(
6150
self,
62-
user_id: Optional[str],
51+
user_id: str | None = None,
6352
name: str = "ProxyAgent",
6453
description: str = (
6554
"Clarification agent. Ask this when instructions are unclear or additional "
6655
"user details are required."
6756
),
68-
timeout_seconds: Optional[int] = None,
57+
timeout_seconds: int | None = None,
58+
**kwargs: Any,
6959
):
60+
super().__init__(
61+
name=name,
62+
description=description,
63+
**kwargs
64+
)
7065
self.user_id = user_id or ""
71-
self.name = name
72-
self.description = description
7366
self._timeout = timeout_seconds or orchestration_config.default_timeout
74-
self._conversation = ProxyConversation()
7567

7668
# ---------------------------
77-
# Public invocation interfaces
69+
# AgentProtocol implementation
7870
# ---------------------------
7971

80-
async def invoke(self, message: str) -> AsyncIterator[ChatResponseUpdate]:
72+
async def run(
73+
self,
74+
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
75+
*,
76+
thread: AgentThread | None = None,
77+
**kwargs: Any,
78+
) -> AgentRunResponse:
8179
"""
82-
One-shot style: waits for human clarification, then yields a single final response update.
80+
Get complete clarification response (non-streaming).
81+
82+
Args:
83+
messages: The message(s) requiring clarification
84+
thread: Optional conversation thread
85+
kwargs: Additional keyword arguments
86+
87+
Returns:
88+
AgentRunResponse with the clarification
8389
"""
84-
async for update in self.invoke_stream(message):
85-
# If caller expects only the final text, they can just collect the last update
86-
continue
87-
# When invoke_stream finishes, it already yielded final updates;
88-
# this wrapper exists for parity with LLM agents returning enumerables.
89-
return
90-
91-
async def invoke_stream(self, message: str) -> AsyncIterator[ChatResponseUpdate]:
90+
# Collect all streaming updates
91+
response_messages: list[ChatMessage] = []
92+
response_id = str(uuid.uuid4())
93+
94+
async for update in self.run_stream(messages, thread=thread, **kwargs):
95+
if update.contents:
96+
response_messages.append(
97+
ChatMessage(
98+
role=update.role or Role.ASSISTANT,
99+
contents=update.contents,
100+
)
101+
)
102+
103+
return AgentRunResponse(
104+
messages=response_messages,
105+
response_id=response_id,
106+
)
107+
108+
def run_stream(
109+
self,
110+
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
111+
*,
112+
thread: AgentThread | None = None,
113+
**kwargs: Any,
114+
) -> AsyncIterable[AgentRunResponseUpdate]:
92115
"""
93-
Streaming version:
94-
1. Sends clarification request via websocket (no yield yet).
95-
2. Waits for human response / timeout.
96-
3. Yields:
97-
- A ChatResponseUpdate with the final clarified answer (as assistant text) if received.
98-
- A usage marker (synthetic) for downstream consistency.
116+
Stream clarification process with human interaction.
117+
118+
Args:
119+
messages: The message(s) requiring clarification
120+
thread: Optional conversation thread
121+
kwargs: Additional keyword arguments
122+
123+
Yields:
124+
AgentRunResponseUpdate objects with clarification progress
99125
"""
100-
original_prompt = message or ""
101-
self._conversation.add(original_prompt)
126+
return self._invoke_stream_internal(messages, thread, **kwargs)
102127

103-
clarification_req_text = f"I need clarification about: {original_prompt}"
128+
async def _invoke_stream_internal(
129+
self,
130+
messages: str | ChatMessage | list[str] | list[ChatMessage] | None,
131+
thread: AgentThread | None,
132+
**kwargs: Any,
133+
) -> AsyncIterable[AgentRunResponseUpdate]:
134+
"""
135+
Internal streaming implementation.
136+
137+
1. Sends clarification request via websocket
138+
2. Waits for human response / timeout
139+
3. Yields AgentRunResponseUpdate with the clarified answer
140+
"""
141+
# Normalize messages to string
142+
message_text = self._extract_message_text(messages)
143+
144+
clarification_req_text = f"I need clarification about: {message_text}"
104145
clarification_request = UserClarificationRequest(
105146
question=clarification_req_text,
106147
request_id=str(uuid.uuid4()),
@@ -117,12 +158,14 @@ async def invoke_stream(self, message: str) -> AsyncIterator[ChatResponseUpdate]
117158
)
118159

119160
# Await human clarification
120-
human_response = await self._wait_for_user_clarification(clarification_request.request_id)
161+
human_response = await self._wait_for_user_clarification(
162+
clarification_request.request_id
163+
)
121164

122165
if human_response is None:
123-
# Timeout or cancellation already handled (timeout notification was sent).
166+
# Timeout or cancellation - end silently
124167
logger.debug(
125-
"ProxyAgentAF: No clarification response (timeout/cancel). Ending stream silently."
168+
"ProxyAgent: No clarification response (timeout/cancel). Ending stream."
126169
)
127170
return
128171

@@ -132,23 +175,61 @@ async def invoke_stream(self, message: str) -> AsyncIterator[ChatResponseUpdate]
132175
else "No additional clarification provided."
133176
)
134177
synthetic_reply = f"Human clarification: {answer_text}"
135-
self._conversation.add(synthetic_reply)
136178

137-
# Yield final assistant text chunk
138-
yield self._make_text_update(synthetic_reply, is_final=False)
179+
# Yield final assistant text update
180+
yield AgentRunResponseUpdate(
181+
role=Role.ASSISTANT,
182+
contents=[TextContent(text=synthetic_reply)],
183+
author_name=self.name,
184+
response_id=str(uuid.uuid4()),
185+
message_id=str(uuid.uuid4()),
186+
)
139187

140-
# Yield a synthetic usage update so downstream consumers can treat this like a model run
141-
yield self._make_usage_update(token_estimate=len(synthetic_reply.split()))
188+
# Yield synthetic usage update for consistency
189+
yield AgentRunResponseUpdate(
190+
role=Role.ASSISTANT,
191+
contents=[
192+
UsageContent(
193+
UsageDetails(
194+
input_token_count=0,
195+
output_token_count=len(synthetic_reply.split()),
196+
total_token_count=len(synthetic_reply.split()),
197+
)
198+
)
199+
],
200+
author_name=self.name,
201+
response_id=str(uuid.uuid4()),
202+
message_id=str(uuid.uuid4()),
203+
)
142204

143205
# ---------------------------
144-
# Internal helpers
206+
# Helper methods
145207
# ---------------------------
146208

209+
def _extract_message_text(
210+
self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None
211+
) -> str:
212+
"""Extract text from various message formats."""
213+
if messages is None:
214+
return ""
215+
if isinstance(messages, str):
216+
return messages
217+
if isinstance(messages, ChatMessage):
218+
return messages.text or ""
219+
if isinstance(messages, list):
220+
if not messages:
221+
return ""
222+
if isinstance(messages[0], str):
223+
return " ".join(messages)
224+
if isinstance(messages[0], ChatMessage):
225+
return " ".join(msg.text or "" for msg in messages)
226+
return str(messages)
227+
147228
async def _wait_for_user_clarification(
148229
self, request_id: str
149-
) -> Optional[UserClarificationResponse]:
230+
) -> UserClarificationResponse | None:
150231
"""
151-
Wraps orchestration_config.wait_for_clarification with robust timeout & cleanup.
232+
Wait for user clarification with timeout and cancellation handling.
152233
"""
153234
orchestration_config.set_clarification_pending(request_id)
154235
try:
@@ -158,26 +239,26 @@ async def _wait_for_user_clarification(
158239
await self._notify_timeout(request_id)
159240
return None
160241
except asyncio.CancelledError:
161-
logger.debug("ProxyAgentAF: Clarification request %s cancelled", request_id)
242+
logger.debug("ProxyAgent: Clarification request %s cancelled", request_id)
162243
orchestration_config.cleanup_clarification(request_id)
163244
return None
164245
except KeyError:
165-
logger.debug("ProxyAgentAF: Invalid clarification request id %s", request_id)
246+
logger.debug("ProxyAgent: Invalid clarification request id %s", request_id)
166247
return None
167-
except Exception as ex:
168-
logger.debug("ProxyAgentAF: Unexpected error awaiting clarification: %s", ex)
248+
except Exception as ex:
249+
logger.debug("ProxyAgent: Unexpected error awaiting clarification: %s", ex)
169250
orchestration_config.cleanup_clarification(request_id)
170251
return None
171252
finally:
172-
# Safety net cleanup if still pending with no value.
253+
# Safety net cleanup
173254
if (
174255
request_id in orchestration_config.clarifications
175256
and orchestration_config.clarifications[request_id] is None
176257
):
177258
orchestration_config.cleanup_clarification(request_id)
178259

179260
async def _notify_timeout(self, request_id: str) -> None:
180-
"""Send a timeout notification to the client and clean up."""
261+
"""Send timeout notification to the client."""
181262
notice = TimeoutNotification(
182263
timeout_type="clarification",
183264
request_id=request_id,
@@ -195,59 +276,27 @@ async def _notify_timeout(self, request_id: str) -> None:
195276
message_type=WebsocketMessageType.TIMEOUT_NOTIFICATION,
196277
)
197278
logger.info(
198-
"ProxyAgentAF: Timeout notification sent (request_id=%s user=%s)",
279+
"ProxyAgent: Timeout notification sent (request_id=%s user=%s)",
199280
request_id,
200281
self.user_id,
201282
)
202-
except Exception as ex:
203-
logger.error("ProxyAgentAF: Failed to send timeout notification: %s", ex)
283+
except Exception as ex:
284+
logger.error("ProxyAgent: Failed to send timeout notification: %s", ex)
204285
orchestration_config.cleanup_clarification(request_id)
205286

206-
def _make_text_update(
207-
self,
208-
text: str,
209-
is_final: bool,
210-
) -> ChatResponseUpdate:
211-
"""
212-
Build a ChatResponseUpdate containing assistant text. We treat each
213-
emitted text as a 'delta'; downstream can concatenate if needed.
214-
"""
215-
return ChatResponseUpdate(
216-
role=Role.ASSISTANT,
217-
text=text,
218-
contents=[TextContent(text=text)],
219-
conversation_id=self._conversation.conversation_id,
220-
message_id=str(uuid.uuid4()),
221-
response_id=str(uuid.uuid4()),
222-
)
223-
224-
def _make_usage_update(self, token_estimate: int) -> ChatResponseUpdate:
225-
"""
226-
Provide a synthetic usage update (assist in downstream finalization logic).
227-
"""
228-
usage = UsageContent(
229-
UsageDetails(
230-
input_token_count=0,
231-
output_token_count=token_estimate,
232-
total_token_count=token_estimate,
233-
)
234-
)
235-
return ChatResponseUpdate(
236-
role=Role.ASSISTANT,
237-
text="",
238-
contents=[usage],
239-
conversation_id=self._conversation.conversation_id,
240-
message_id=str(uuid.uuid4()),
241-
response_id=str(uuid.uuid4()),
242-
)
243-
244287

245288
# ---------------------------------------------------------------------------
246289
# Factory
247290
# ---------------------------------------------------------------------------
248291

249-
async def create_proxy_agent(user_id: Optional[str] = None) -> ProxyAgent:
292+
async def create_proxy_agent(user_id: str | None = None) -> ProxyAgent:
250293
"""
251-
Factory for ProxyAgentAF (mirrors previous create_proxy_agent interface).
294+
Factory for ProxyAgent.
295+
296+
Args:
297+
user_id: User ID for websocket communication
298+
299+
Returns:
300+
Initialized ProxyAgent instance
252301
"""
253302
return ProxyAgent(user_id=user_id)

0 commit comments

Comments
 (0)