|
| 1 | +""" |
| 2 | +Agent Framework response callbacks for employee onboarding / multi-agent system. |
| 3 | +Replaces Semantic Kernel message types with agent_framework ChatResponseUpdate handling. |
| 4 | +""" |
| 5 | + |
| 6 | +import asyncio |
| 7 | +import json |
| 8 | +import logging |
| 9 | +import re |
| 10 | +import time |
| 11 | +from typing import Optional |
| 12 | + |
| 13 | +from agent_framework import ( |
| 14 | + ChatResponseUpdate, |
| 15 | + FunctionCallContent, |
| 16 | + UsageContent, |
| 17 | + Role, |
| 18 | + TextContent, |
| 19 | +) |
| 20 | + |
| 21 | +from af.config.settings import connection_config |
| 22 | +from af.models.messages import ( |
| 23 | + AgentMessage, |
| 24 | + AgentMessageStreaming, |
| 25 | + AgentToolCall, |
| 26 | + AgentToolMessage, |
| 27 | + WebsocketMessageType, |
| 28 | +) |
| 29 | + |
| 30 | +logger = logging.getLogger(__name__) |
| 31 | + |
| 32 | + |
| 33 | +# --------------------------------------------------------------------------- |
| 34 | +# Utility |
| 35 | +# --------------------------------------------------------------------------- |
| 36 | + |
| 37 | +_CITATION_PATTERNS = [ |
| 38 | + (r"\[\d+:\d+\|source\]", ""), # [9:0|source] |
| 39 | + (r"\[\s*source\s*\]", ""), # [source] |
| 40 | + (r"\[\d+\]", ""), # [12] |
| 41 | + (r"【[^】]*】", ""), # Unicode bracket citations |
| 42 | + (r"\(source:[^)]*\)", ""), # (source: xyz) |
| 43 | + (r"\[source:[^\]]*\]", ""), # [source: xyz] |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +def clean_citations(text: str) -> str: |
| 48 | + """Remove citation markers from agent responses while preserving formatting.""" |
| 49 | + if not text: |
| 50 | + return text |
| 51 | + for pattern, repl in _CITATION_PATTERNS: |
| 52 | + text = re.sub(pattern, repl, text, flags=re.IGNORECASE) |
| 53 | + return text |
| 54 | + |
| 55 | + |
| 56 | +def _parse_function_arguments(arg_value: Optional[str | dict]) -> dict: |
| 57 | + """Best-effort parse for function call arguments (stringified JSON or dict).""" |
| 58 | + if arg_value is None: |
| 59 | + return {} |
| 60 | + if isinstance(arg_value, dict): |
| 61 | + return arg_value |
| 62 | + if isinstance(arg_value, str): |
| 63 | + try: |
| 64 | + return json.loads(arg_value) |
| 65 | + except Exception: # noqa: BLE001 |
| 66 | + return {"raw": arg_value} |
| 67 | + return {"raw": str(arg_value)} |
| 68 | + |
| 69 | + |
| 70 | +# --------------------------------------------------------------------------- |
| 71 | +# Core handlers |
| 72 | +# --------------------------------------------------------------------------- |
| 73 | + |
| 74 | +def agent_framework_update_callback( |
| 75 | + update: ChatResponseUpdate, |
| 76 | + user_id: Optional[str] = None, |
| 77 | +) -> None: |
| 78 | + """ |
| 79 | + Handle a non-streaming perspective of updates (tool calls, intermediate steps, final usage). |
| 80 | + This can be called for each ChatResponseUpdate; it will route tool calls and standard text |
| 81 | + messages to WebSocket. |
| 82 | + """ |
| 83 | + agent_name = getattr(update, "model_id", None) or "Agent" |
| 84 | + # Use Role or fallback |
| 85 | + role = getattr(update, "role", Role.ASSISTANT) |
| 86 | + |
| 87 | + # Detect tool/function calls |
| 88 | + function_call_contents = [ |
| 89 | + c for c in (update.contents or []) |
| 90 | + if isinstance(c, FunctionCallContent) |
| 91 | + ] |
| 92 | + |
| 93 | + if user_id is None: |
| 94 | + return |
| 95 | + |
| 96 | + try: |
| 97 | + if function_call_contents: |
| 98 | + # Build tool message |
| 99 | + tool_message = AgentToolMessage(agent_name=agent_name) |
| 100 | + for fc in function_call_contents: |
| 101 | + args = _parse_function_arguments(getattr(fc, "arguments", None)) |
| 102 | + tool_message.tool_calls.append( |
| 103 | + AgentToolCall( |
| 104 | + tool_name=getattr(fc, "name", "unknown_tool"), |
| 105 | + arguments=args, |
| 106 | + ) |
| 107 | + ) |
| 108 | + asyncio.create_task( |
| 109 | + connection_config.send_status_update_async( |
| 110 | + tool_message, |
| 111 | + user_id, |
| 112 | + message_type=WebsocketMessageType.AGENT_TOOL_MESSAGE, |
| 113 | + ) |
| 114 | + ) |
| 115 | + logger.info("Function call(s) dispatched: %s", tool_message) |
| 116 | + return |
| 117 | + |
| 118 | + # Ignore pure usage or empty updates (handled as final in streaming handler) |
| 119 | + if any(isinstance(c, UsageContent) for c in (update.contents or [])): |
| 120 | + # We'll treat this as a final token accounting event; no standard message needed. |
| 121 | + logger.debug("UsageContent received (final accounting); skipping text dispatch.") |
| 122 | + return |
| 123 | + |
| 124 | + # Standard assistant/user message (non-stream delta) |
| 125 | + if update.text: |
| 126 | + final_message = AgentMessage( |
| 127 | + agent_name=agent_name, |
| 128 | + timestamp=str(time.time()), |
| 129 | + content=clean_citations(update.text), |
| 130 | + ) |
| 131 | + asyncio.create_task( |
| 132 | + connection_config.send_status_update_async( |
| 133 | + final_message, |
| 134 | + user_id, |
| 135 | + message_type=WebsocketMessageType.AGENT_MESSAGE, |
| 136 | + ) |
| 137 | + ) |
| 138 | + logger.info("%s message: %s", role.name.capitalize(), final_message) |
| 139 | + |
| 140 | + except Exception as e: # noqa: BLE001 |
| 141 | + logger.error("agent_framework_update_callback: Error sending WebSocket message: %s", e) |
| 142 | + |
| 143 | + |
| 144 | +async def streaming_agent_framework_callback( |
| 145 | + update: ChatResponseUpdate, |
| 146 | + user_id: Optional[str] = None, |
| 147 | +) -> None: |
| 148 | + """ |
| 149 | + Handle streaming deltas. For each update with text, forward a streaming message. |
| 150 | + Mark is_final=True when a UsageContent is observed (end of run). |
| 151 | + """ |
| 152 | + if user_id is None: |
| 153 | + return |
| 154 | + |
| 155 | + try: |
| 156 | + # Determine if this update marks the end |
| 157 | + is_final = any(isinstance(c, UsageContent) for c in (update.contents or [])) |
| 158 | + |
| 159 | + # Streaming text can appear either in update.text or inside TextContent entries. |
| 160 | + pieces: list[str] = [] |
| 161 | + if update.text: |
| 162 | + pieces.append(update.text) |
| 163 | + # Some events may provide TextContent objects without setting update.text |
| 164 | + for c in (update.contents or []): |
| 165 | + if isinstance(c, TextContent) and getattr(c, "text", None): |
| 166 | + pieces.append(c.text) |
| 167 | + |
| 168 | + if not pieces: |
| 169 | + return |
| 170 | + |
| 171 | + streaming_message = AgentMessageStreaming( |
| 172 | + agent_name=getattr(update, "model_id", None) or "Agent", |
| 173 | + content=clean_citations("".join(pieces)), |
| 174 | + is_final=is_final, |
| 175 | + ) |
| 176 | + |
| 177 | + await connection_config.send_status_update_async( |
| 178 | + streaming_message, |
| 179 | + user_id, |
| 180 | + message_type=WebsocketMessageType.AGENT_MESSAGE_STREAMING, |
| 181 | + ) |
| 182 | + |
| 183 | + if is_final: |
| 184 | + logger.info("Final streaming chunk sent for agent '%s'", streaming_message.agent_name) |
| 185 | + |
| 186 | + except Exception as e: # noqa: BLE001 |
| 187 | + logger.error("streaming_agent_framework_callback: Error sending streaming WebSocket message: %s", e) |
| 188 | + |
| 189 | + |
| 190 | +# --------------------------------------------------------------------------- |
| 191 | +# Convenience wrappers (optional) |
| 192 | +# --------------------------------------------------------------------------- |
| 193 | + |
| 194 | +def handle_update(update: ChatResponseUpdate, user_id: Optional[str]) -> None: |
| 195 | + """ |
| 196 | + Unified entry point if caller doesn't distinguish streaming vs non-streaming. |
| 197 | + You can call this once per update. It will: |
| 198 | + - Forward streaming text increments |
| 199 | + - Forward tool calls |
| 200 | + - Skip purely usage-only events (except marking final in streaming) |
| 201 | + """ |
| 202 | + # Send streaming chunk first (async context) |
| 203 | + asyncio.create_task(streaming_agent_framework_callback(update, user_id)) |
| 204 | + # Then send non-stream items (tool calls or discrete messages) |
| 205 | + agent_framework_update_callback(update, user_id) |
| 206 | + |
0 commit comments