Skip to content

Commit de80f8c

Browse files
authored
Python: Include streaming code output for OpenAI Assistants (microsoft#9080)
### Motivation and Context In the recent release of OpenAI Assistant streaming responses, there was a gap related to not exposing code interpreter output, when available. <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> ### Description This PR closes that gap, and adds the ability to yield messages that contain code input/output. A new concept sample was added to show this. - Unit test coverage also added. <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone 😄
1 parent 28976b0 commit de80f8c

File tree

5 files changed

+225
-29
lines changed

5 files changed

+225
-29
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
import asyncio
3+
import os
4+
5+
from semantic_kernel.agents.open_ai import OpenAIAssistantAgent
6+
from semantic_kernel.agents.open_ai.azure_assistant_agent import AzureAssistantAgent
7+
from semantic_kernel.contents.chat_message_content import ChatMessageContent
8+
from semantic_kernel.contents.utils.author_role import AuthorRole
9+
from semantic_kernel.kernel import Kernel
10+
11+
#####################################################################
12+
# The following sample demonstrates how to create an OpenAI #
13+
# assistant using either Azure OpenAI or OpenAI and leverage the #
14+
# assistant's ability to stream the response and have the code #
15+
# interpreter work with uploaded files #
16+
#####################################################################
17+
18+
# Name and instructions used for the assistant created in main().
AGENT_NAME = "FileManipulation"
AGENT_INSTRUCTIONS = "Find answers to the user's questions in the provided file."
20+
21+
22+
# A helper method to invoke the agent with the user input
23+
async def invoke_streaming_agent(agent: OpenAIAssistantAgent, thread_id: str, input: str) -> None:
    """Invoke the streaming agent with the user input and print the streamed response.

    Args:
        agent: The assistant agent to invoke.
        thread_id: The id of the thread the message is added to.
        input: The user input to send to the agent.
    """
    await agent.add_chat_message(thread_id=thread_id, message=ChatMessageContent(role=AuthorRole.USER, content=input))

    print(f"# {AuthorRole.USER}: '{input}'")

    first_chunk = True
    async for content in agent.invoke_stream(thread_id=thread_id):
        if content.role != AuthorRole.TOOL:
            # Stream assistant text inline, printing the role header only once.
            if first_chunk:
                print(f"# {content.role}: ", end="", flush=True)
                first_chunk = False
            print(content.content, end="", flush=True)
        elif content.metadata.get("code"):
            # TOOL-role chunks flagged with "code" carry code interpreter input.
            # (The original re-checked `content.role == AuthorRole.TOOL` here,
            # which is always true in this elif branch.)
            print("")
            print(f"# {content.role} (code):\n\n{content.content}")
    print()
40+
41+
42+
async def main():
    """Run the streaming code-interpreter sample end to end."""
    # Create the instance of the Kernel
    kernel = Kernel()

    # Define a service_id for the sample
    service_id = "agent"

    # Get the path to the sales.csv file used by the code interpreter
    csv_file_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
        "resources",
        "agent_assistant_file_manipulation",
        "sales.csv",
    )

    # Create the assistant agent with the code interpreter enabled and the CSV uploaded
    agent = await AzureAssistantAgent.create(
        kernel=kernel,
        service_id=service_id,
        name=AGENT_NAME,
        instructions=AGENT_INSTRUCTIONS,
        enable_code_interpreter=True,
        code_interpreter_filenames=[csv_file_path],
    )

    # Create a thread for the conversation
    thread_id = await agent.create_thread()

    try:
        await invoke_streaming_agent(agent, thread_id=thread_id, input="Which segment had the most sales?")
        await invoke_streaming_agent(
            agent, thread_id=thread_id, input="List the top 5 countries that generated the most profit."
        )
        await invoke_streaming_agent(
            agent,
            thread_id=thread_id,
            input="Create a tab delimited file report of profit by each country per month.",
        )
    finally:
        # Clean up remote resources: uploaded files, the thread, then the assistant.
        # Use a plain loop (not a side-effect list comprehension) for the deletes.
        if agent is not None:
            for file_id in agent.code_interpreter_file_ids:
                await agent.delete_file(file_id)
            await agent.delete_thread(thread_id)
            await agent.delete()
85+
86+
87+
# Run the sample only when executed as a script (not on import).
if __name__ == "__main__":
    asyncio.run(main())

python/semantic_kernel/agents/open_ai/assistant_content_generation.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from openai.types.beta.threads.image_file_content_block import ImageFileContentBlock
99
from openai.types.beta.threads.image_file_delta_block import ImageFileDeltaBlock
1010
from openai.types.beta.threads.message_delta_event import MessageDeltaEvent
11+
from openai.types.beta.threads.runs.code_interpreter_tool_call import CodeInterpreter
1112
from openai.types.beta.threads.text_content_block import TextContentBlock
1213
from openai.types.beta.threads.text_delta_block import TextDeltaBlock
1314

@@ -32,6 +33,7 @@
3233
from openai.types.beta.threads.annotation import Annotation
3334
from openai.types.beta.threads.runs import RunStep
3435
from openai.types.beta.threads.runs.tool_call import ToolCall
36+
from openai.types.beta.threads.runs.tool_calls_step_details import ToolCallsStepDetails
3537

3638

3739
###################################################################
@@ -258,6 +260,56 @@ def generate_code_interpreter_content(agent_name: str, code: str) -> "ChatMessag
258260
)
259261

260262

263+
@experimental_function
def generate_streaming_tools_content(
    agent_name: str, step_details: "ToolCallsStepDetails"
) -> "StreamingChatMessageContent | None":
    """Generate streaming content for code interpreter tool calls in a run step.

    Args:
        agent_name: The agent name.
        step_details: The current step details.

    Returns:
        StreamingChatMessageContent | None: The streaming chat message content,
        or None when the step contains no code interpreter input or image output.
    """
    items: list[StreamingTextContent | StreamingFileReferenceContent] = []

    metadata: dict[str, bool] = {}
    for index, tool in enumerate(step_details.tool_calls):
        if tool.type != "code_interpreter":
            continue
        if tool.code_interpreter.input:
            items.append(
                StreamingTextContent(
                    choice_index=index,
                    text=tool.code_interpreter.input,
                )
            )
            # Flag the message so callers can tell code content apart from text.
            metadata["code"] = True
        for output in tool.code_interpreter.outputs:
            # Outputs are a union of image and log entries; only image outputs
            # carry a file id. (The original asserted isinstance(output,
            # CodeInterpreter), which is the tool-call payload type, not an
            # output type, and would raise on any real output.)
            if getattr(output, "type", None) == "image" and output.image.file_id:
                items.append(
                    StreamingFileReferenceContent(
                        file_id=output.image.file_id,
                    )
                )

    return (
        StreamingChatMessageContent(
            role=AuthorRole.TOOL,
            name=agent_name,
            items=items,  # type: ignore
            choice_index=0,
            metadata=metadata if metadata else None,
        )
        if items
        else None
    )
311+
312+
261313
@experimental_function
262314
def generate_annotation_content(annotation: "Annotation") -> AnnotationContent:
263315
"""Generate annotation content."""

python/semantic_kernel/agents/open_ai/open_ai_assistant_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
generate_function_result_content,
2626
generate_message_content,
2727
generate_streaming_message_content,
28+
generate_streaming_tools_content,
2829
get_function_call_contents,
2930
get_message_contents,
3031
)
@@ -920,6 +921,10 @@ async def _invoke_internal_stream(
920921
message_id = event.data.step_details.message_creation.message_id
921922
if message_id not in active_messages:
922923
active_messages[message_id] = event.data
924+
elif hasattr(event.data.step_details, "tool_calls"):
925+
tool_content = generate_streaming_tools_content(self.name, event.data.step_details)
926+
if tool_content:
927+
yield tool_content
923928
elif event.event == "thread.run.requires_action":
924929
run = event.data
925930
function_action_result = await self._handle_streaming_requires_action(run, function_steps)

python/tests/unit/agents/test_open_ai_assistant_base.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,34 @@ def mock_thread_run_step_completed():
517517
)
518518

519519

520+
def mock_thread_run_step_completed_with_code():
    """Build a completed run-step event whose details contain a code interpreter call."""
    return ThreadRunStepCompleted(
        data=RunStep(
            id="step_id_2",
            # The step details below are tool calls, so the step type must be
            # "tool_calls" ("message_creation" contradicted the details payload).
            type="tool_calls",
            completed_at=int(datetime.now(timezone.utc).timestamp()),
            created_at=int((datetime.now(timezone.utc) - timedelta(minutes=2)).timestamp()),
            step_details=ToolCallsStepDetails(
                type="tool_calls",
                tool_calls=[
                    CodeInterpreterToolCall(
                        id="tool_call_id",
                        code_interpreter=CodeInterpreter(input="test code", outputs=[]),
                        type="code_interpreter",
                    )
                ],
            ),
            assistant_id="assistant_id",
            object="thread.run.step",
            run_id="run_id",
            status="completed",
            thread_id="thread_id",
            usage=Usage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
        ),
        event="thread.run.step.completed",
    )
546+
547+
520548
def mock_run_with_last_error():
521549
return ThreadRunFailed(
522550
data=Run(
@@ -1161,6 +1189,31 @@ async def test_invoke_stream(
11611189
assert len(messages) > 0
11621190

11631191

1192+
@pytest.mark.asyncio
async def test_invoke_stream_code_output(
    azure_openai_assistant_agent,
    mock_assistant,
    azure_openai_unit_test_env,
):
    """Streaming a tool-call run step yields content flagged with metadata["code"]."""
    events = [mock_thread_run_step_completed_with_code()]

    with patch.object(azure_openai_assistant_agent, "client", spec=AsyncAzureOpenAI) as mock_client:
        mock_client.beta = MagicMock()
        mock_client.beta.threads = MagicMock()
        mock_client.beta.assistants = MagicMock()
        mock_client.beta.assistants.create = AsyncMock(return_value=mock_assistant)

        mock_client.beta.threads.runs = MagicMock()
        mock_client.beta.threads.runs.stream = MagicMock(return_value=MockStream(events))

        azure_openai_assistant_agent.assistant = await azure_openai_assistant_agent.create_assistant()

        messages = []
        received = []
        async for content in azure_openai_assistant_agent.invoke_stream("thread_id", messages=messages):
            assert content is not None
            assert content.metadata.get("code") is True
            received.append(content)

        # Guard against a vacuously passing test: the per-chunk assertions above
        # never run if the stream yields nothing, so require at least one chunk.
        assert received
1215+
1216+
11641217
@pytest.mark.asyncio
11651218
async def test_invoke_stream_requires_action(
11661219
azure_openai_assistant_agent, mock_assistant, mock_thread_messages, azure_openai_unit_test_env

0 commit comments

Comments
 (0)