diff --git a/tests/test_inference.py b/tests/test_inference.py
index 73d05a0..d7bb02a 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -199,20 +199,21 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     with Client() as client:
         llm = client.llm.model(model_id)
         chat = Chat()
-        chat.add_user_message("What is the sum of 123 and 3210?")
+        # Ensure the first response is a combination of text and tool use requests
+        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
         tools = [ADDITION_TOOL_SPEC]
         round_starts: list[int] = []
         round_ends: list[int] = []
         first_tokens: list[int] = []
         predictions: list[PredictionRoundResult] = []
         fragments: list[LlmPredictionFragment] = []
-        last_fragment_round_index = 0
+        fragment_round_indices: set[int] = set()
 
         def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
-            nonlocal last_fragment_round_index
+            last_fragment_round_index = max(fragment_round_indices, default=-1)
             assert round_index >= last_fragment_round_index
             fragments.append(f)
-            last_fragment_round_index = round_index
+            fragment_round_indices.add(round_index)
 
         # TODO: Also check on_prompt_processing_progress and handling invalid messages
         # (although it isn't clear how to provoke calls to the latter without mocking)
@@ -233,8 +234,9 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
         assert round_starts == sequential_round_indices
         assert round_ends == sequential_round_indices
         expected_token_indices = [p.round_index for p in predictions if p.content]
+        assert expected_token_indices == sequential_round_indices
         assert first_tokens == expected_token_indices
-        assert last_fragment_round_index == num_rounds - 1
+        assert fragment_round_indices == set(expected_token_indices)
         assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
 
         cloned_chat = chat.copy()
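
Note (illustration only, not part of the patch): a minimal standalone sketch of the fragment-round tracking the revised callback performs. The names here (record_fragment_round) are hypothetical. The key detail is that max() over an empty set with default=-1 yields -1, so the first fragment of round 0 trivially satisfies the non-decreasing assertion, while both repeated and strictly increasing round indices are accepted.

    # Sketch of the monotonicity check, assuming only the standard library.
    fragment_round_indices: set[int] = set()

    def record_fragment_round(round_index: int) -> None:
        # Highest round seen so far; -1 before any fragment arrives
        last_fragment_round_index = max(fragment_round_indices, default=-1)
        assert round_index >= last_fragment_round_index
        fragment_round_indices.add(round_index)

    record_fragment_round(0)  # 0 >= -1: the first fragment always passes
    record_fragment_round(0)  # the same round again is fine (0 >= 0)
    record_fragment_round(1)  # round indices may only move forward
    assert fragment_round_indices == {0, 1}

Compared with the old single last_fragment_round_index counter, the set records exactly which rounds emitted fragments, which is what allows the new assertion that fragment rounds match the rounds that produced content.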