forked from microsoft/agent-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoverride_result_with_middleware.py
More file actions
220 lines (175 loc) · 8.33 KB
/
override_result_with_middleware.py
File metadata and controls
220 lines (175 loc) · 8.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import re
from collections.abc import AsyncIterable, Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentContext,
AgentResponse,
AgentResponseUpdate,
ChatContext,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
ResponseStream,
tool,
)
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()
"""
Result Override with Middleware (Regular and Streaming)
This sample demonstrates how to use middleware to intercept and modify function results
after execution, supporting both regular and streaming agent responses. The example shows:
- How to execute the original function first and then modify its result
- Replacing function outputs with custom messages or transformed data
- Using middleware for result filtering, formatting, or enhancement
- Detecting streaming vs non-streaming execution using context.stream
- Overriding streaming results with custom async generators
The weather override middleware lets the original weather function execute normally,
then replaces its result with a custom "perfect weather" message. For streaming responses,
it creates a custom async generator that yields the override message in chunks.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated weather report for the requested location."""
    # Pick one of four canned conditions and a random high temperature.
    condition = ("sunny", "cloudy", "rainy", "stormy")[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
async def weather_override_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
    """Chat middleware that overrides weather results for both streaming and non-streaming cases."""
    # Run the wrapped execution to completion before inspecting its output.
    await call_next()
    if context.result is None:
        # No result was produced, so there is nothing to override.
        return
    # The pieces of the custom "perfect weather" advisory.
    advisory_parts = [
        "due to special atmospheric conditions, ",
        "all locations are experiencing perfect weather today! ",
        "Temperature is a comfortable 22°C with gentle breezes. ",
        "Perfect day for outdoor activities!",
    ]
    if context.stream and isinstance(context.result, ResponseStream):
        # Streaming: swap in an async generator that yields the advisory in chunks.
        async def _advisory_stream() -> AsyncIterable[ChatResponseUpdate]:
            for index, part in enumerate(advisory_parts):
                yield ChatResponseUpdate(
                    contents=[Content.from_text(text=f"Weather Advisory: [{index}] {part}")],
                    role="assistant",
                )

        context.result = ResponseStream(_advisory_stream())
        return
    # Non-streaming: replace the result with a single assistant message that
    # carries the full advisory plus the original text for reference.
    original_text = context.result.text if isinstance(context.result, ChatResponse) else ""
    replacement = f"Weather Advisory: [0] {''.join(advisory_parts)} Original message was: {original_text}"
    context.result = ChatResponse(messages=[Message(role="assistant", text=replacement)])
async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
    """Chat middleware that simulates result validation for both streaming and non-streaming cases."""
    await call_next()
    result = context.result
    if result is None:
        # Nothing to validate.
        return
    validation_note = "Validation: weather data verified."
    if context.stream and isinstance(result, ResponseStream):
        # Streaming: defer the note until the stream finalizes into a ChatResponse.
        def _append_note(response: ChatResponse) -> ChatResponse:
            response.messages.append(Message(role="assistant", text=validation_note))
            return response

        result.with_finalizer(_append_note)
    elif isinstance(result, ChatResponse):
        # Non-streaming: append the note directly to the finished response.
        result.messages.append(Message(role="assistant", text=validation_note))
async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
    """Agent middleware that validates chat middleware effects and cleans the result."""
    # Let the agent (and any chat middleware beneath it) run first.
    await call_next()
    if context.result is None:
        return
    # Must match the note appended by validate_weather_middleware exactly,
    # since we locate it by substring below.
    validation_note = "Validation: weather data verified."
    # Mutable cell shared between the streaming transform hook and the
    # finalizer: the hook records whether the chat-middleware prefix was
    # seen in any streamed chunk, and _sanitize reads it afterwards.
    state = {"found_prefix": False}
    def _sanitize(response: AgentResponse) -> AgentResponse:
        # Seed from state so a prefix spotted during streaming (where the
        # prefix has already been stripped from the text) still counts.
        found_prefix = state["found_prefix"]
        found_validation = False
        cleaned_messages: list[Message] = []
        for message in response.messages:
            text = message.text
            if text is None:
                # Non-text messages pass through untouched.
                cleaned_messages.append(message)
                continue
            if validation_note in text:
                found_validation = True
                text = text.replace(validation_note, "").strip()
                if not text:
                    # The message was only the validation note; drop it entirely.
                    continue
            if "Weather Advisory:" in text:
                found_prefix = True
                # Strip the advisory prefix and the "[N] " chunk markers.
                text = text.replace("Weather Advisory:", "")
                text = re.sub(r"\[\d+\]\s*", "", text)
            # Rebuild the message with cleaned text, preserving all other fields.
            cleaned_messages.append(
                Message(
                    role=message.role,
                    text=text.strip(),
                    author_name=message.author_name,
                    message_id=message.message_id,
                    additional_properties=message.additional_properties,
                    raw_representation=message.raw_representation,
                )
            )
        # Fail loudly if the expected chat-middleware effects never appeared —
        # this is how the sample asserts the middleware chain actually ran.
        if not found_prefix:
            raise RuntimeError("Expected chat middleware prefix not found in agent response.")
        if not found_validation:
            raise RuntimeError("Expected validation note not found in agent response.")
        cleaned_messages.append(Message(role="assistant", text=" Agent: OK"))
        response.messages = cleaned_messages
        return response
    if context.stream and isinstance(context.result, ResponseStream):
        # Streaming: clean each update as it flows through, then run the full
        # sanitize/validation pass on the finalized response.
        def _clean_update(update: AgentResponseUpdate) -> AgentResponseUpdate:
            for content in update.contents or []:
                if not content.text:
                    continue
                text = content.text
                if "Weather Advisory:" in text:
                    # Record the sighting for _sanitize, which runs after the
                    # prefix has already been stripped from the stream.
                    state["found_prefix"] = True
                    text = text.replace("Weather Advisory:", "")
                text = re.sub(r"\[\d+\]\s*", "", text)
                content.text = text
            return update
        context.result.with_transform_hook(_clean_update)
        context.result.with_finalizer(_sanitize)
    elif isinstance(context.result, AgentResponse):
        # Non-streaming: sanitize the completed response in place.
        context.result = _sanitize(context.result)
async def main() -> None:
    """Example demonstrating result override with middleware for both streaming and non-streaming."""
    print("=== Result Override MiddlewareTypes Example ===")

    # Chat-level middleware is attached to the client; agent-level middleware
    # wraps the agent run. NOTE(review): an earlier comment here mentioned
    # AzureCliCredential, but this sample uses OpenAIResponsesClient, which
    # presumably reads its credentials from the environment (.env) — confirm.
    client = OpenAIResponsesClient(
        middleware=[validate_weather_middleware, weather_override_middleware],
    )
    agent = client.as_agent(
        name="WeatherAgent",
        instructions="You are a helpful weather assistant. Use the weather tool to get current conditions.",
        tools=get_weather,
        middleware=[agent_cleanup_middleware],
    )

    # --- Non-streaming: a single awaited run, then print the final result.
    print("\n--- Non-streaming Example ---")
    question = "What's the weather like in Seattle?"
    print(f"User: {question}")
    answer = await agent.run(question)
    print(f"Agent: {answer}")

    # --- Streaming: iterate the updates as they arrive, then fetch the final response.
    print("\n--- Streaming Example ---")
    question = "What's the weather like in Portland?"
    print(f"User: {question}")
    print("Agent: ", end="", flush=True)
    stream = agent.run(question, stream=True)
    async for update in stream:
        if update.text:
            print(update.text, end="", flush=True)
    print("\n")
    final = await stream.get_final_response()
    print(f"Final Result: {final.text}")


if __name__ == "__main__":
    asyncio.run(main())