Skip to content

Commit cc5f0a3

Browse files
committed
test: update usage assertions to include tool_calls in various tests
- Modified assertions in multiple test files to include the new `tool_calls` attribute in `RunUsage`. - Added new test cases to validate tool call limits and ensure correct behavior when exceeding limits. - Updated existing tests to reflect changes in tool call tracking across different models and scenarios.
1 parent 6b45b5f commit cc5f0a3

File tree

10 files changed

+41
-11
lines changed

10 files changed

+41
-11
lines changed

tests/models/test_anthropic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,11 @@ async def my_tool(first: str, second: str) -> int:
659659
requests=2,
660660
input_tokens=20,
661661
output_tokens=5,
662-
details={'input_tokens': 20, 'output_tokens': 5},
662+
tool_calls=1,
663+
details={
664+
'input_tokens': 20,
665+
'output_tokens': 5,
666+
},
663667
)
664668
)
665669
assert tool_called

tests/models/test_bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def temperature(city: str, date: datetime.date) -> str:
111111

112112
result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response)
113113
assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'})
114-
assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298))
114+
assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=1236, output_tokens=298, tool_calls=2))
115115
assert result.all_messages() == snapshot(
116116
[
117117
ModelRequest(

tests/models/test_cohere.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ async def get_location(loc_name: str) -> str:
325325
input_tokens=5,
326326
output_tokens=3,
327327
details={'input_tokens': 4, 'output_tokens': 2},
328+
tool_calls=2,
328329
)
329330
)
330331

tests/models/test_gemini.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ async def get_location(loc_name: str) -> str:
783783
),
784784
]
785785
)
786-
assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6))
786+
assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=3, output_tokens=6, tool_calls=3))
787787

788788

789789
async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
@@ -932,7 +932,7 @@ async def bar(y: str) -> str:
932932
async with agent.run_stream('Hello') as result:
933933
response = await result.get_output()
934934
assert response == snapshot((1, 2))
935-
assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4))
935+
assert result.usage() == snapshot(RunUsage(requests=2, input_tokens=2, output_tokens=4, tool_calls=2))
936936
assert result.all_messages() == snapshot(
937937
[
938938
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),

tests/models/test_google.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ async def temperature(city: str, date: datetime.date) -> str:
147147
requests=2,
148148
input_tokens=224,
149149
output_tokens=35,
150+
tool_calls=2,
150151
details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35},
151152
)
152153
)

tests/models/test_openai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,9 @@ async def get_location(loc_name: str) -> str:
416416
),
417417
]
418418
)
419-
assert result.usage() == snapshot(RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3))
419+
assert result.usage() == snapshot(
420+
RunUsage(requests=3, cache_read_tokens=3, input_tokens=5, output_tokens=3, tool_calls=2)
421+
)
420422

421423

422424
FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1954,7 +1954,7 @@ async def ret_a(x: str) -> str:
19541954
assert result2.output == snapshot(Response(a=0))
19551955
assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage]
19561956
assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage]
1957-
assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13))
1957+
assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13, tool_calls=1))
19581958
new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
19591959
assert new_msg_part_kinds == snapshot(
19601960
[

tests/test_examples.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ async def call_tool(
369369
'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.'
370370
'Rome is known for its rich history, stunning architecture, and delicious cuisine.'
371371
),
372+
'Please call the tool twice': ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id'),
372373
'Begin infinite retry loop!': ToolCallPart(
373374
tool_name='infinite_retry_tool', args={}, tool_call_id='pyd_ai_tool_call_id'
374375
),
@@ -626,6 +627,9 @@ async def model_logic( # noqa: C901
626627
return ModelResponse(
627628
parts=[ToolCallPart(tool_name='final_result', args=args, tool_call_id='pyd_ai_tool_call_id')]
628629
)
630+
elif isinstance(m, ToolReturnPart) and m.tool_name == 'do_work':
631+
# Trigger a second tool call to cause tool_calls_limit to be exceeded in the docs example
632+
return ModelResponse(parts=[ToolCallPart(tool_name='do_work', args={}, tool_call_id='pyd_ai_tool_call_id')])
629633
elif isinstance(m, RetryPromptPart) and m.tool_name == 'calc_volume':
630634
return ModelResponse(
631635
parts=[ToolCallPart(tool_name='calc_volume', args={'size': 6}, tool_call_id='pyd_ai_tool_call_id')]

tests/test_streaming.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def ret_a(x: str) -> str:
8282
requests=2,
8383
input_tokens=103,
8484
output_tokens=5,
85+
tool_calls=1,
8586
)
8687
)
8788
response = await result.get_output()
@@ -117,6 +118,7 @@ async def ret_a(x: str) -> str:
117118
requests=2,
118119
input_tokens=103,
119120
output_tokens=11,
121+
tool_calls=1,
120122
)
121123
)
122124

tests/test_usage_limits.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ async def ret_a(x: str) -> str:
120120
requests=2,
121121
input_tokens=103,
122122
output_tokens=5,
123+
tool_calls=1,
123124
)
124125
)
125126
succeeded = True
@@ -151,26 +152,26 @@ async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
151152
delegate_result = await delegate_agent.run(sentence)
152153
delegate_usage = delegate_result.usage()
153154
run_1_usages.append(delegate_usage)
154-
assert delegate_usage == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=4))
155+
assert delegate_usage == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=4, tool_calls=1))
155156
return delegate_result.output
156157

157158
result1 = await controller_agent1.run('foobar')
158159
assert result1.output == snapshot('{"delegate_to_other_agent1":0}')
159160
run_1_usages.append(result1.usage())
160-
assert result1.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=13))
161+
assert result1.usage() == snapshot(RunUsage(requests=2, input_tokens=103, output_tokens=13, tool_calls=1))
161162

162163
controller_agent2 = Agent(TestModel())
163164

164165
@controller_agent2.tool
165166
async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
166167
delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
167168
delegate_usage = delegate_result.usage()
168-
assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9))
169+
assert delegate_usage == snapshot(RunUsage(requests=2, input_tokens=102, output_tokens=9, tool_calls=2))
169170
return delegate_result.output
170171

171172
result2 = await controller_agent2.run('foobar')
172173
assert result2.output == snapshot('{"delegate_to_other_agent2":0}')
173-
assert result2.usage() == snapshot(RunUsage(requests=3, input_tokens=154, output_tokens=17))
174+
assert result2.usage() == snapshot(RunUsage(requests=3, input_tokens=154, output_tokens=17, tool_calls=2))
174175

175176
# confirm the usage from result2 is the sum of the usage from result1
176177
assert result2.usage() == functools.reduce(operator.add, run_1_usages)
@@ -197,7 +198,7 @@ def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
197198

198199
result = await controller_agent.run('foobar')
199200
assert result.output == snapshot('{"delegate_to_other_agent":0}')
200-
assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16))
201+
assert result.usage() == snapshot(RunUsage(requests=7, input_tokens=105, output_tokens=16, tool_calls=1))
201202

202203

203204
def test_request_usage_basics():
@@ -215,6 +216,7 @@ def test_add_usages():
215216
cache_write_tokens=40,
216217
input_audio_tokens=50,
217218
cache_audio_read_tokens=60,
219+
tool_calls=3,
218220
details={
219221
'custom1': 10,
220222
'custom2': 20,
@@ -229,13 +231,27 @@ def test_add_usages():
229231
cache_read_tokens=60,
230232
input_audio_tokens=100,
231233
cache_audio_read_tokens=120,
234+
tool_calls=6,
232235
details={'custom1': 20, 'custom2': 40},
233236
)
234237
)
235238
assert usage + RunUsage() == usage
236239
assert RunUsage() + RunUsage() == RunUsage()
237240

238241

242+
async def test_tool_call_limit() -> None:
243+
test_agent = Agent(TestModel())
244+
245+
@test_agent.tool_plain
246+
async def ret_a(x: str) -> str:
247+
return f'{x}-apple'
248+
249+
with pytest.raises(
250+
UsageLimitExceeded, match=re.escape('The next tool call would exceed the tool_calls_limit of 0 (tool_calls=0)')
251+
):
252+
await test_agent.run('Hello', usage_limits=UsageLimits(tool_calls_limit=0))
253+
254+
239255
def test_deprecated_usage_limits():
240256
with warns(
241257
snapshot(['DeprecationWarning: `request_tokens_limit` is deprecated, use `input_tokens_limit` instead'])

0 commit comments

Comments
 (0)