Commit e810133

WIP: Account for cache costs when calculating cost limit
Adds test with multiple models to validate behavior.
1 parent 75e430f commit e810133

File tree

2 files changed: +70 -1 lines changed


src/inspect_ai/model/_model.py

Lines changed: 6 additions & 1 deletion
@@ -617,7 +617,12 @@ async def generate() -> tuple[ModelOutput, BaseModel]:
                 output=existing,
                 call=None,
             )
-            # TODO: Update cost info based on the cache hit
+            # Cost limits should still be updated on cache hits
+            if existing.usage:
+                total_cost = calculate_model_usage_cost(
+                    {cache_entry.model: existing.usage}
+                )
+                record_model_usage_cost(total_cost)
             return existing, event
         else:
             cache_entry = None
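
For context, the intent of the new branch is that usage served from the cache is priced the same way as fresh usage and recorded against the running cost limit; calculate_model_usage_cost and record_model_usage_cost are the helpers the diff relies on. Below is a minimal, self-contained sketch of that accounting with hypothetical names and the per-token rates from the new test's cost_config.json, not the actual inspect_ai implementation:

# Hypothetical sketch: price cached usage and count it toward the cost limit.
from dataclasses import dataclass

@dataclass
class Usage:
    input_tokens: int
    input_tokens_cache_read: int
    output_tokens: int

# Per-token rates mirroring the cost_config.json used in the test below.
INPUT_RATE = 0.01
CACHE_READ_RATE = 0.10
OUTPUT_RATE = 0.001

def usage_cost(usage: Usage) -> float:
    # A cache hit still reads cached input tokens and returns output tokens,
    # so it contributes a nonzero cost.
    return (
        usage.input_tokens * INPUT_RATE
        + usage.input_tokens_cache_read * CACHE_READ_RATE
        + usage.output_tokens * OUTPUT_RATE
    )

total_cost = 0.0
cached_usage = Usage(input_tokens=1, input_tokens_cache_read=1, output_tokens=1)
total_cost += usage_cost(cached_usage)  # 0.111, recorded even though it was a cache hit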

tests/test_sample_limits.py

Lines changed: 64 additions & 0 deletions
@@ -340,6 +340,70 @@ def test_cost_limit(tmp_path):
     check_cost_limit_event(log, cost_limit)


+def test_multi_model_cost_limit(tmp_path):
+    # TODO: how can we support multiple models
+
+    # Build a temporary JSON file under tmp_path (set by pytest)
+    cost_file = tmp_path / "cost_config.json"
+    data = {
+        "model": {
+            "input_cost_per_token": 0.01,
+            "output_cost_per_token": 0.001,
+            "cache_read_input_token_cost": 0.10,
+        },
+        "other_model": {
+            "input_cost_per_token": 0.01,
+            "output_cost_per_token": 0.001,
+            "cache_read_input_token_cost": 0.10,
+        },
+    }
+    cost_file.write_text(json.dumps(data))
+
+    model1 = get_model(
+        "mockllm/model",
+        custom_outputs=repeat_forever(
+            mock_model_output(
+                # Configure so each generation produces 1 unique input, 1 cached
+                # input, and 1 output token (3 tokens total)
+                input_tokens=1,  # Unique input tokens
+                input_tokens_cache_read=1,  # Cached input tokens
+                output_tokens=1,
+                total_tokens=3,
+            )
+        ),
+    )
+
+    model2 = get_model(
+        "mockllm/other_model",
+        custom_outputs=repeat_forever(
+            mock_model_output(
+                # Configure so each generation produces 1 unique input, 1 cached
+                # input, and 1 output token (3 tokens total)
+                input_tokens=1,  # Unique input tokens
+                input_tokens_cache_read=1,  # Cached input tokens
+                output_tokens=1,
+                total_tokens=3,
+            )
+        ),
+    )
+    # With the simulated costs, each turn should cost $0.111, so after 10 turns
+    # we should hit the limit at 30 total tokens.
+    # The cost limit should be hit while the token and message limits should not.
+    token_limit = 31
+    message_limit = 21  # Expect 10 messages from "user", 10 from assistant
+    cost_limit = 1.00
+
+    log = eval(
+        Task(solver=looping_solver()),
+        model=[model1, model2],
+        token_limit=token_limit,
+        message_limit=message_limit,
+        cost_limit=cost_limit,
+        cost_file=cost_file,
+    )[0]
+    check_cost_limit_event(log, cost_limit)
+
+
 @pytest.mark.slow
 @skip_if_no_docker
 def test_working_limit_does_not_raise_during_sandbox_teardown() -> None:
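
As a sanity check on the numbers in the comments above (assuming the cost limit trips once the cumulative cost exceeds it), the per-generation cost and the headroom left in the token and message limits work out as follows; this is plain arithmetic, not part of the test:

# Per-generation cost with the simulated rates and token counts above.
per_generation = 1 * 0.01 + 1 * 0.10 + 1 * 0.001   # unique input + cache read + output
assert round(per_generation, 3) == 0.111

turns = 10
assert turns * per_generation > 1.00   # $1.11 cumulative exceeds the $1.00 cost limit
assert 9 * per_generation < 1.00       # but 9 turns ($0.999) stay under it
assert turns * 3 <= 31                 # 30 total tokens stay under the token limit of 31
assert turns * 2 <= 21                 # 20 messages (10 user + 10 assistant) stay under 21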
