|
8 | 8 | AssistantMessage, |
9 | 9 | ChatMessage, |
10 | 10 | LLMClient, |
| 11 | + SummaryMessage, |
11 | 12 | SystemMessage, |
12 | 13 | TokenUsage, |
13 | 14 | Tool, |
|
22 | 23 | class MockLLMClient(LLMClient): |
23 | 24 | """Mock LLM client for testing.""" |
24 | 25 |
|
25 | | - def __init__(self, responses: list[AssistantMessage]) -> None: |
| 26 | + def __init__(self, responses: list[AssistantMessage], max_tokens: int = 100_000) -> None: |
26 | 27 | self.responses = responses |
27 | 28 | self.call_count = 0 |
| 29 | + self._max_tokens = max_tokens |
28 | 30 |
|
    @property
    def model_slug(self) -> str:
        """Fixed identifier for the fake model used in tests."""
        return "mock-model"
32 | 34 |
|
    @property
    def max_tokens(self) -> int:
        # Context-window size. Configurable via the constructor so tests can
        # use a small budget to trigger summarization deterministically.
        return self._max_tokens
36 | 38 |
|
37 | 39 | async def generate(self, messages: list[ChatMessage], tools: dict[str, Tool]) -> AssistantMessage: # noqa: ARG002 |
38 | 40 | response = self.responses[self.call_count] |
@@ -493,3 +495,90 @@ async def test_allow_successive_assistant_messages() -> None: |
493 | 495 | messages = message_history[0] |
494 | 496 | continue_messages = [m for m in messages if isinstance(m, UserMessage) and m.content == "Please continue the task"] |
495 | 497 | assert len(continue_messages) == 0 |
| 498 | + |
| 499 | + |
async def test_summarize_history_has_one_summary_per_trajectory() -> None:
    """Test that each sub-trajectory in history contains at most one SummaryMessage.

    Simulates an agent run where summarization triggers twice. Verifies:
    - history[0] (pre-first-summary) has 0 SummaryMessages
    - history[1] (post-first-summary) has exactly 1 SummaryMessage
    - history[2] (post-second-summary, final) has exactly 1 SummaryMessage
    """
    # With max_tokens=1000 and context_summarization_cutoff=0.3, any turn whose
    # token_usage.total reaches 300 triggers summarization. Turns without tool
    # calls would also draw "Please continue" nudges from
    # block_successive_assistant_messages.

    scripted_responses = [
        # Turn 1: total=350 >= 300 -> first summarization fires.
        AssistantMessage(
            content="Working on it",
            tool_calls=[],
            token_usage=TokenUsage(input=250, answer=100),
        ),
        # generate() call that produces the first summary.
        AssistantMessage(
            content="First summary of progress.",
            tool_calls=[],
            token_usage=TokenUsage(input=200, answer=50),
        ),
        # Turn 2: total=350 >= 300 again -> second summarization fires.
        AssistantMessage(
            content="Continuing work",
            tool_calls=[],
            token_usage=TokenUsage(input=250, answer=100),
        ),
        # generate() call that produces the second summary.
        AssistantMessage(
            content="Second summary of progress.",
            tool_calls=[],
            token_usage=TokenUsage(input=200, answer=50),
        ),
        # Turn 3: agent wraps up via the finish tool.
        AssistantMessage(
            content="Done",
            tool_calls=[
                ToolCall(
                    name=FINISH_TOOL_NAME,
                    arguments='{"reason": "Completed", "paths": []}',
                    tool_call_id="call_finish",
                )
            ],
            token_usage=TokenUsage(input=100, answer=50),
        ),
    ]

    agent = Agent(
        client=MockLLMClient(scripted_responses, max_tokens=1000),
        name="test-agent",
        max_turns=10,
        turns_remaining_warning_threshold=2,
        tools=[],
        finish_tool=SIMPLE_FINISH_TOOL,
        context_summarization_cutoff=0.3,
    )

    async with agent.session() as session:
        _finish_params, history, _ = await session.run(
            [SystemMessage(content="System prompt"), UserMessage(content="Do the task")]
        )

    # Expect three sub-trajectories:
    # pre-summary, post-1st-summary, post-2nd-summary (final).
    assert len(history) == 3

    pre_summaries = [m for m in history[0] if isinstance(m, SummaryMessage)]
    first_summaries = [m for m in history[1] if isinstance(m, SummaryMessage)]
    second_summaries = [m for m in history[2] if isinstance(m, SummaryMessage)]

    # Original conversation carries no summaries yet.
    assert len(pre_summaries) == 0
    # Each later sub-trajectory carries exactly one (never accumulates two).
    assert len(first_summaries) == 1
    assert len(second_summaries) == 1

    # The second summarization must have replaced the first summary's content.
    assert first_summaries[0].content != second_summaries[0].content
0 commit comments