12 changes: 7 additions & 5 deletions tests/test_inference.py
@@ -199,20 +199,21 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     with Client() as client:
         llm = client.llm.model(model_id)
         chat = Chat()
-        chat.add_user_message("What is the sum of 123 and 3210?")
+        # Ensure the first response is a combination of text and tool use requests
+        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
         tools = [ADDITION_TOOL_SPEC]
         round_starts: list[int] = []
         round_ends: list[int] = []
         first_tokens: list[int] = []
         predictions: list[PredictionRoundResult] = []
         fragments: list[LlmPredictionFragment] = []
-        last_fragment_round_index = 0
+        fragment_round_indices: set[int] = set()

         def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
-            nonlocal last_fragment_round_index
+            last_fragment_round_index = max(fragment_round_indices, default=-1)
             assert round_index >= last_fragment_round_index
             fragments.append(f)
-            last_fragment_round_index = round_index
+            fragment_round_indices.add(round_index)

         # TODO: Also check on_prompt_processing_progress and handling invalid messages
         # (although it isn't clear how to provoke calls to the latter without mocking)
@@ -233,8 +234,9 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
         assert round_starts == sequential_round_indices
         assert round_ends == sequential_round_indices
         expected_token_indices = [p.round_index for p in predictions if p.content]
+        assert expected_token_indices == sequential_round_indices
         assert first_tokens == expected_token_indices
-        assert last_fragment_round_index == num_rounds - 1
+        assert fragment_round_indices == set(expected_token_indices)
         assert len(chat._messages) == 2 * num_rounds  # No tool results in last round

         cloned_chat = chat.copy()
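
Note on the revised callback bookkeeping: the fragment callback now records every round index it sees in a set, checks that indices never move backwards, and the test later compares that set against the rounds that actually produced content. Below is a minimal standalone sketch of the pattern, assuming plain str fragments in place of the SDK's LlmPredictionFragment; the sample calls at the end are illustrative only, not the SDK API.

# Sketch only: mirrors the test's bookkeeping, not the library itself.
fragment_round_indices: set[int] = set()
fragments: list[str] = []

def _append_fragment(fragment: str, round_index: int) -> None:
    # Highest round index seen so far; -1 before any fragment arrives.
    last_fragment_round_index = max(fragment_round_indices, default=-1)
    # Fragments may repeat a round, but must never revisit an earlier one.
    assert round_index >= last_fragment_round_index
    fragments.append(fragment)
    fragment_round_indices.add(round_index)

# Example: text in round 0, then more text after a tool call in round 1.
_append_fragment("Hi", 0)
_append_fragment("1 + 3 is ", 1)
_append_fragment("4.", 1)
assert fragment_round_indices == {0, 1}

Deriving the previous maximum from the set (with default=-1) removes the need for the nonlocal counter used in the old version of the test.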