Skip to content

Commit 0fc7933

Browse files
authored
Fix WorkflowAgent to include thread convo history. Enable checkpointing. (#2774)
1 parent d7434d5 commit 0fc7933

File tree

5 files changed

+508
-6
lines changed

5 files changed

+508
-6
lines changed

python/packages/core/agent_framework/_workflows/_agent.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424

2525
from ..exceptions import AgentExecutionException
26+
from ._checkpoint import CheckpointStorage
2627
from ._events import (
2728
AgentRunUpdateEvent,
2829
RequestInfoEvent,
@@ -117,17 +118,25 @@ async def run(
117118
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
118119
*,
119120
thread: AgentThread | None = None,
121+
checkpoint_id: str | None = None,
122+
checkpoint_storage: CheckpointStorage | None = None,
120123
**kwargs: Any,
121124
) -> AgentRunResponse:
122125
"""Get a response from the workflow agent (non-streaming).
123126
124127
This method collects all streaming updates and merges them into a single response.
125128
126129
Args:
127-
messages: The message(s) to send to the workflow.
130+
messages: The message(s) to send to the workflow. Required for new runs,
131+
should be None when resuming from checkpoint.
128132
129133
Keyword Args:
130134
thread: The conversation thread. If None, a new thread will be created.
135+
checkpoint_id: ID of checkpoint to restore from. If provided, the workflow
136+
resumes from this checkpoint instead of starting fresh.
137+
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
138+
used to load and restore the checkpoint. When provided without checkpoint_id,
139+
enables checkpointing for this run.
131140
**kwargs: Additional keyword arguments.
132141
133142
Returns:
@@ -139,7 +148,9 @@ async def run(
139148
thread = thread or self.get_new_thread()
140149
response_id = str(uuid.uuid4())
141150

142-
async for update in self._run_stream_impl(input_messages, response_id):
151+
async for update in self._run_stream_impl(
152+
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
153+
):
143154
response_updates.append(update)
144155

145156
# Convert updates to final response.
@@ -155,15 +166,23 @@ async def run_stream(
155166
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
156167
*,
157168
thread: AgentThread | None = None,
169+
checkpoint_id: str | None = None,
170+
checkpoint_storage: CheckpointStorage | None = None,
158171
**kwargs: Any,
159172
) -> AsyncIterable[AgentRunResponseUpdate]:
160173
"""Stream response updates from the workflow agent.
161174
162175
Args:
163-
messages: The message(s) to send to the workflow.
176+
messages: The message(s) to send to the workflow. Required for new runs,
177+
should be None when resuming from checkpoint.
164178
165179
Keyword Args:
166180
thread: The conversation thread. If None, a new thread will be created.
181+
checkpoint_id: ID of checkpoint to restore from. If provided, the workflow
182+
resumes from this checkpoint instead of starting fresh.
183+
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
184+
used to load and restore the checkpoint. When provided without checkpoint_id,
185+
enables checkpointing for this run.
167186
**kwargs: Additional keyword arguments.
168187
169188
Yields:
@@ -174,7 +193,9 @@ async def run_stream(
174193
response_updates: list[AgentRunResponseUpdate] = []
175194
response_id = str(uuid.uuid4())
176195

177-
async for update in self._run_stream_impl(input_messages, response_id):
196+
async for update in self._run_stream_impl(
197+
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
198+
):
178199
response_updates.append(update)
179200
yield update
180201

@@ -188,12 +209,18 @@ async def _run_stream_impl(
188209
self,
189210
input_messages: list[ChatMessage],
190211
response_id: str,
212+
thread: AgentThread,
213+
checkpoint_id: str | None = None,
214+
checkpoint_storage: CheckpointStorage | None = None,
191215
) -> AsyncIterable[AgentRunResponseUpdate]:
192216
"""Internal implementation of streaming execution.
193217
194218
Args:
195219
input_messages: Normalized input messages to process.
196220
response_id: The unique response ID for this workflow execution.
221+
thread: The conversation thread containing message history.
222+
checkpoint_id: ID of checkpoint to restore from.
223+
checkpoint_storage: Runtime checkpoint storage.
197224
198225
Yields:
199226
AgentRunResponseUpdate objects representing the workflow execution progress.
@@ -217,10 +244,27 @@ async def _run_stream_impl(
217244
# and we will let the workflow to handle this -- the agent does not
218245
# have an opinion on this.
219246
event_stream = self.workflow.send_responses_streaming(function_responses)
247+
elif checkpoint_id is not None:
248+
# Resume from checkpoint - don't prepend thread history since workflow state
249+
# is being restored from the checkpoint
250+
event_stream = self.workflow.run_stream(
251+
message=None,
252+
checkpoint_id=checkpoint_id,
253+
checkpoint_storage=checkpoint_storage,
254+
)
220255
else:
221256
# Execute workflow with streaming (initial run or no function responses)
222-
# Pass the new input messages directly to the workflow
223-
event_stream = self.workflow.run_stream(input_messages)
257+
# Build the complete conversation by prepending thread history to input messages
258+
conversation_messages: list[ChatMessage] = []
259+
if thread.message_store:
260+
history = await thread.message_store.list_messages()
261+
if history:
262+
conversation_messages.extend(history)
263+
conversation_messages.extend(input_messages)
264+
event_stream = self.workflow.run_stream(
265+
message=conversation_messages,
266+
checkpoint_storage=checkpoint_storage,
267+
)
224268

225269
# Process events from the stream
226270
async for event in event_stream:

python/packages/core/tests/workflow/test_workflow_agent.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
AgentRunResponse,
1010
AgentRunResponseUpdate,
1111
AgentRunUpdateEvent,
12+
AgentThread,
1213
ChatMessage,
14+
ChatMessageStore,
1315
Executor,
1416
FunctionApprovalRequestContent,
1517
FunctionApprovalResponseContent,
@@ -75,6 +77,31 @@ async def handle_request_response(
7577
await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=update))
7678

7779

80+
class ConversationHistoryCapturingExecutor(Executor):
81+
"""Executor that captures the received conversation history for verification."""
82+
83+
def __init__(self, id: str):
84+
super().__init__(id=id)
85+
self.received_messages: list[ChatMessage] = []
86+
87+
@handler
88+
async def handle_message(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
89+
# Capture all received messages
90+
self.received_messages = list(messages)
91+
92+
# Count messages by role for the response
93+
message_count = len(messages)
94+
response_text = f"Received {message_count} messages"
95+
96+
response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)])
97+
98+
streaming_update = AgentRunResponseUpdate(
99+
contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4())
100+
)
101+
await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update))
102+
await ctx.send_message([response_message])
103+
104+
78105
class TestWorkflowAgent:
79106
"""Test cases for WorkflowAgent end-to-end functionality."""
80107

@@ -257,6 +284,105 @@ async def handle_bool(self, message: bool, context: WorkflowContext[Any]) -> Non
257284
with pytest.raises(ValueError, match="Workflow's start executor cannot handle list\\[ChatMessage\\]"):
258285
workflow.as_agent()
259286

287+
async def test_thread_conversation_history_included_in_workflow_run(self) -> None:
288+
"""Test that conversation history from thread is included when running WorkflowAgent.
289+
290+
This verifies that when a thread with existing messages is provided to agent.run(),
291+
the workflow receives the complete conversation history (thread history + new messages).
292+
"""
293+
# Create an executor that captures all received messages
294+
capturing_executor = ConversationHistoryCapturingExecutor(id="capturing")
295+
workflow = WorkflowBuilder().set_start_executor(capturing_executor).build()
296+
agent = WorkflowAgent(workflow=workflow, name="Thread History Test Agent")
297+
298+
# Create a thread with existing conversation history
299+
history_messages = [
300+
ChatMessage(role=Role.USER, text="Previous user message"),
301+
ChatMessage(role=Role.ASSISTANT, text="Previous assistant response"),
302+
]
303+
message_store = ChatMessageStore(messages=history_messages)
304+
thread = AgentThread(message_store=message_store)
305+
306+
# Run the agent with the thread and a new message
307+
new_message = "New user question"
308+
await agent.run(new_message, thread=thread)
309+
310+
# Verify the executor received both history AND new message
311+
assert len(capturing_executor.received_messages) == 3
312+
313+
# Verify the order: history first, then new message
314+
assert capturing_executor.received_messages[0].text == "Previous user message"
315+
assert capturing_executor.received_messages[1].text == "Previous assistant response"
316+
assert capturing_executor.received_messages[2].text == "New user question"
317+
318+
async def test_thread_conversation_history_included_in_workflow_stream(self) -> None:
319+
"""Test that conversation history from thread is included when streaming WorkflowAgent.
320+
321+
This verifies that run_stream also includes thread history.
322+
"""
323+
# Create an executor that captures all received messages
324+
capturing_executor = ConversationHistoryCapturingExecutor(id="capturing_stream")
325+
workflow = WorkflowBuilder().set_start_executor(capturing_executor).build()
326+
agent = WorkflowAgent(workflow=workflow, name="Thread Stream Test Agent")
327+
328+
# Create a thread with existing conversation history
329+
history_messages = [
330+
ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"),
331+
ChatMessage(role=Role.USER, text="Hello"),
332+
ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
333+
]
334+
message_store = ChatMessageStore(messages=history_messages)
335+
thread = AgentThread(message_store=message_store)
336+
337+
# Stream from the agent with the thread and a new message
338+
async for _ in agent.run_stream("How are you?", thread=thread):
339+
pass
340+
341+
# Verify the executor received all messages (3 from history + 1 new)
342+
assert len(capturing_executor.received_messages) == 4
343+
344+
# Verify the order
345+
assert capturing_executor.received_messages[0].text == "You are a helpful assistant"
346+
assert capturing_executor.received_messages[1].text == "Hello"
347+
assert capturing_executor.received_messages[2].text == "Hi there!"
348+
assert capturing_executor.received_messages[3].text == "How are you?"
349+
350+
async def test_empty_thread_works_correctly(self) -> None:
351+
"""Test that an empty thread (no message store) works correctly."""
352+
capturing_executor = ConversationHistoryCapturingExecutor(id="empty_thread_test")
353+
workflow = WorkflowBuilder().set_start_executor(capturing_executor).build()
354+
agent = WorkflowAgent(workflow=workflow, name="Empty Thread Test Agent")
355+
356+
# Create an empty thread
357+
thread = AgentThread()
358+
359+
# Run with the empty thread
360+
await agent.run("Just a new message", thread=thread)
361+
362+
# Should only receive the new message
363+
assert len(capturing_executor.received_messages) == 1
364+
assert capturing_executor.received_messages[0].text == "Just a new message"
365+
366+
async def test_checkpoint_storage_passed_to_workflow(self) -> None:
367+
"""Test that checkpoint_storage parameter is passed through to the workflow."""
368+
from agent_framework import InMemoryCheckpointStorage
369+
370+
capturing_executor = ConversationHistoryCapturingExecutor(id="checkpoint_test")
371+
workflow = WorkflowBuilder().set_start_executor(capturing_executor).build()
372+
agent = WorkflowAgent(workflow=workflow, name="Checkpoint Test Agent")
373+
374+
# Create checkpoint storage
375+
checkpoint_storage = InMemoryCheckpointStorage()
376+
377+
# Run with checkpoint storage enabled
378+
async for _ in agent.run_stream("Test message", checkpoint_storage=checkpoint_storage):
379+
pass
380+
381+
# Drain workflow events to get checkpoint
382+
# The workflow should have created checkpoints
383+
checkpoints = await checkpoint_storage.list_checkpoints(workflow.id)
384+
assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided"
385+
260386

261387
class TestWorkflowAgentMergeUpdates:
262388
"""Test cases specifically for the WorkflowAgent.merge_updates static method."""

python/samples/getting_started/workflows/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ Once comfortable with these, explore the rest of the samples below.
4444
| Magentic Workflow as Agent | [agents/magentic_workflow_as_agent.py](./agents/magentic_workflow_as_agent.py) | Configure Magentic orchestration with callbacks, then expose the workflow as an agent |
4545
| Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) |
4646
| Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability |
47+
| Workflow as Agent with Thread | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentThread to maintain conversation history across workflow-as-agent invocations |
4748
| Handoff Workflow as Agent | [agents/handoff_workflow_as_agent.py](./agents/handoff_workflow_as_agent.py) | Use a HandoffBuilder workflow as an agent with HITL via FunctionCallContent/FunctionResultContent |
4849

4950
### checkpoint
@@ -54,6 +55,7 @@ Once comfortable with these, explore the rest of the samples below.
5455
| Checkpoint & HITL Resume | [checkpoint/checkpoint_with_human_in_the_loop.py](./checkpoint/checkpoint_with_human_in_the_loop.py) | Combine checkpointing with human approvals and resume pending HITL requests |
5556
| Checkpointed Sub-Workflow | [checkpoint/sub_workflow_checkpoint.py](./checkpoint/sub_workflow_checkpoint.py) | Save and resume a sub-workflow that pauses for human approval |
5657
| Handoff + Tool Approval Resume | [checkpoint/handoff_with_tool_approval_checkpoint_resume.py](./checkpoint/handoff_with_tool_approval_checkpoint_resume.py) | Handoff workflow that captures tool-call approvals in checkpoints and resumes with human decisions |
58+
| Workflow as Agent Checkpoint | [checkpoint/workflow_as_agent_checkpoint.py](./checkpoint/workflow_as_agent_checkpoint.py) | Enable checkpointing when using workflow.as_agent() with checkpoint_storage parameter |
5759

5860
### composition
5961

0 commit comments

Comments
 (0)