|
9 | 9 | AgentRunResponse, |
10 | 10 | AgentRunResponseUpdate, |
11 | 11 | AgentRunUpdateEvent, |
| 12 | + AgentThread, |
12 | 13 | ChatMessage, |
| 14 | + ChatMessageStore, |
13 | 15 | Executor, |
14 | 16 | FunctionApprovalRequestContent, |
15 | 17 | FunctionApprovalResponseContent, |
@@ -75,6 +77,31 @@ async def handle_request_response( |
75 | 77 | await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=update)) |
76 | 78 |
|
77 | 79 |
|
| 80 | +class ConversationHistoryCapturingExecutor(Executor): |
| 81 | + """Executor that captures the received conversation history for verification.""" |
| 82 | + |
| 83 | + def __init__(self, id: str): |
| 84 | + super().__init__(id=id) |
| 85 | + self.received_messages: list[ChatMessage] = [] |
| 86 | + |
| 87 | + @handler |
| 88 | + async def handle_message(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: |
| 89 | + # Capture all received messages |
| 90 | + self.received_messages = list(messages) |
| 91 | + |
| 92 | + # Count messages by role for the response |
| 93 | + message_count = len(messages) |
| 94 | + response_text = f"Received {message_count} messages" |
| 95 | + |
| 96 | + response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) |
| 97 | + |
| 98 | + streaming_update = AgentRunResponseUpdate( |
| 99 | + contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) |
| 100 | + ) |
| 101 | + await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) |
| 102 | + await ctx.send_message([response_message]) |
| 103 | + |
| 104 | + |
78 | 105 | class TestWorkflowAgent: |
79 | 106 | """Test cases for WorkflowAgent end-to-end functionality.""" |
80 | 107 |
|
@@ -257,6 +284,105 @@ async def handle_bool(self, message: bool, context: WorkflowContext[Any]) -> Non |
257 | 284 | with pytest.raises(ValueError, match="Workflow's start executor cannot handle list\\[ChatMessage\\]"): |
258 | 285 | workflow.as_agent() |
259 | 286 |
|
| 287 | + async def test_thread_conversation_history_included_in_workflow_run(self) -> None: |
| 288 | + """Test that conversation history from thread is included when running WorkflowAgent. |
| 289 | +
|
| 290 | + This verifies that when a thread with existing messages is provided to agent.run(), |
| 291 | + the workflow receives the complete conversation history (thread history + new messages). |
| 292 | + """ |
| 293 | + # Create an executor that captures all received messages |
| 294 | + capturing_executor = ConversationHistoryCapturingExecutor(id="capturing") |
| 295 | + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() |
| 296 | + agent = WorkflowAgent(workflow=workflow, name="Thread History Test Agent") |
| 297 | + |
| 298 | + # Create a thread with existing conversation history |
| 299 | + history_messages = [ |
| 300 | + ChatMessage(role=Role.USER, text="Previous user message"), |
| 301 | + ChatMessage(role=Role.ASSISTANT, text="Previous assistant response"), |
| 302 | + ] |
| 303 | + message_store = ChatMessageStore(messages=history_messages) |
| 304 | + thread = AgentThread(message_store=message_store) |
| 305 | + |
| 306 | + # Run the agent with the thread and a new message |
| 307 | + new_message = "New user question" |
| 308 | + await agent.run(new_message, thread=thread) |
| 309 | + |
| 310 | + # Verify the executor received both history AND new message |
| 311 | + assert len(capturing_executor.received_messages) == 3 |
| 312 | + |
| 313 | + # Verify the order: history first, then new message |
| 314 | + assert capturing_executor.received_messages[0].text == "Previous user message" |
| 315 | + assert capturing_executor.received_messages[1].text == "Previous assistant response" |
| 316 | + assert capturing_executor.received_messages[2].text == "New user question" |
| 317 | + |
| 318 | + async def test_thread_conversation_history_included_in_workflow_stream(self) -> None: |
| 319 | + """Test that conversation history from thread is included when streaming WorkflowAgent. |
| 320 | +
|
| 321 | + This verifies that run_stream also includes thread history. |
| 322 | + """ |
| 323 | + # Create an executor that captures all received messages |
| 324 | + capturing_executor = ConversationHistoryCapturingExecutor(id="capturing_stream") |
| 325 | + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() |
| 326 | + agent = WorkflowAgent(workflow=workflow, name="Thread Stream Test Agent") |
| 327 | + |
| 328 | + # Create a thread with existing conversation history |
| 329 | + history_messages = [ |
| 330 | + ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant"), |
| 331 | + ChatMessage(role=Role.USER, text="Hello"), |
| 332 | + ChatMessage(role=Role.ASSISTANT, text="Hi there!"), |
| 333 | + ] |
| 334 | + message_store = ChatMessageStore(messages=history_messages) |
| 335 | + thread = AgentThread(message_store=message_store) |
| 336 | + |
| 337 | + # Stream from the agent with the thread and a new message |
| 338 | + async for _ in agent.run_stream("How are you?", thread=thread): |
| 339 | + pass |
| 340 | + |
| 341 | + # Verify the executor received all messages (3 from history + 1 new) |
| 342 | + assert len(capturing_executor.received_messages) == 4 |
| 343 | + |
| 344 | + # Verify the order |
| 345 | + assert capturing_executor.received_messages[0].text == "You are a helpful assistant" |
| 346 | + assert capturing_executor.received_messages[1].text == "Hello" |
| 347 | + assert capturing_executor.received_messages[2].text == "Hi there!" |
| 348 | + assert capturing_executor.received_messages[3].text == "How are you?" |
| 349 | + |
| 350 | + async def test_empty_thread_works_correctly(self) -> None: |
| 351 | + """Test that an empty thread (no message store) works correctly.""" |
| 352 | + capturing_executor = ConversationHistoryCapturingExecutor(id="empty_thread_test") |
| 353 | + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() |
| 354 | + agent = WorkflowAgent(workflow=workflow, name="Empty Thread Test Agent") |
| 355 | + |
| 356 | + # Create an empty thread |
| 357 | + thread = AgentThread() |
| 358 | + |
| 359 | + # Run with the empty thread |
| 360 | + await agent.run("Just a new message", thread=thread) |
| 361 | + |
| 362 | + # Should only receive the new message |
| 363 | + assert len(capturing_executor.received_messages) == 1 |
| 364 | + assert capturing_executor.received_messages[0].text == "Just a new message" |
| 365 | + |
| 366 | + async def test_checkpoint_storage_passed_to_workflow(self) -> None: |
| 367 | + """Test that checkpoint_storage parameter is passed through to the workflow.""" |
| 368 | + from agent_framework import InMemoryCheckpointStorage |
| 369 | + |
| 370 | + capturing_executor = ConversationHistoryCapturingExecutor(id="checkpoint_test") |
| 371 | + workflow = WorkflowBuilder().set_start_executor(capturing_executor).build() |
| 372 | + agent = WorkflowAgent(workflow=workflow, name="Checkpoint Test Agent") |
| 373 | + |
| 374 | + # Create checkpoint storage |
| 375 | + checkpoint_storage = InMemoryCheckpointStorage() |
| 376 | + |
| 377 | + # Run with checkpoint storage enabled |
| 378 | + async for _ in agent.run_stream("Test message", checkpoint_storage=checkpoint_storage): |
| 379 | + pass |
| 380 | + |
| 381 | + # Drain workflow events to get checkpoint |
| 382 | + # The workflow should have created checkpoints |
| 383 | + checkpoints = await checkpoint_storage.list_checkpoints(workflow.id) |
| 384 | + assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided" |
| 385 | + |
260 | 386 |
|
261 | 387 | class TestWorkflowAgentMergeUpdates: |
262 | 388 | """Test cases specifically for the WorkflowAgent.merge_updates static method.""" |
|
0 commit comments