Commit 3653973

add a method to Messages to set a cache breakpoint (#335)

1 parent 551ed0e commit 3653973

6 files changed: +212 -21 lines

src/aviary/env.py

Lines changed: 7 additions & 1 deletion

@@ -550,7 +550,13 @@ async def step(
     ) -> tuple[Messages, float, bool, bool]:
         msgs: Messages = await self.exec_tool_calls(  # type: ignore[assignment]
             action, state=self.state, concurrency=self.concurrent_tool_calls
-        ) or [Message(content=f"No tool calls input in tool request {action}.")]
+        ) or [
+            ToolResponseMessage(
+                content=f"No tool calls input in tool request {action}.",
+                name="",
+                tool_call_id="",
+            )
+        ]
         self.state.messages.extend(msgs)
         return msgs, self.state.reward, self.state.done, False
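
For context, a minimal standalone sketch of the new fallback path (illustration only, not part of the commit; it assumes ToolResponseMessage is exported from aviary.core, and uses a hypothetical action_repr string in place of a real ToolRequestMessage):

from aviary.core import ToolResponseMessage  # assumed import path

action_repr = "ToolRequestMessage(tool_calls=[])"  # hypothetical empty request
fallback = ToolResponseMessage(
    content=f"No tool calls input in tool request {action_repr}.",
    name="",
    tool_call_id="",
)
# Serializing with deserialize_content=False keeps content as a plain string,
# which tool-role messages need for LLM API compatibility (see message.py below).
print(fallback.model_dump(context={"deserialize_content": False}))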

src/aviary/message.py

Lines changed: 46 additions & 5 deletions

@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, ClassVar, Self
 
@@ -18,6 +19,8 @@
 
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 class Message(BaseModel):
     DEFAULT_ROLE: ClassVar[str] = "user"
@@ -60,6 +63,14 @@ class Message(BaseModel):
         repr=False,
     )
 
+    cache_breakpoint: bool = Field(
+        default=False,
+        description="Mark this message as a cache breakpoint for prompt caching. "
+        "When True, adds cache_control to the content during serialization.",
+        exclude=True,
+        repr=False,
+    )
+
     @field_validator("role")
     @classmethod
     def check_role(cls, v: str) -> str:
@@ -97,14 +108,44 @@ def _serialize(self, handler, info: SerializationInfo):
           as LLM APIs expect multimodal content as structured blocks.
         - Other structured content stays as a JSON string,
          as tool response content must be a string for LLM API compatibility.
+
+        For cache_breakpoint:
+        - When True, adds cache_control to the content for prompt caching.
+        - String content is converted to content block format, the list-of-dicts
+          representation that LLM APIs use for structured content, e.g.
+          [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral"}}].
+        - Multimodal content has cache_control added to the last block.
         """
         data = handler(self)
-        if (
-            self.is_multimodal
-            and "content" in data
-            and (info.context or {}).get("deserialize_content", True)
-        ):
+        deserialize_content = (info.context or {}).get("deserialize_content", True)
+        if self.is_multimodal and "content" in data and deserialize_content:
             data["content"] = json.loads(data["content"])
+
+        # Handle cache_breakpoint - add cache_control to content
+        # Skip when deserialize_content=False as it would convert string to list,
+        # breaking call sites that require string content (e.g., ToolResponseMessage)
+        if self.cache_breakpoint and not deserialize_content:
+            logger.warning(
+                "cache_breakpoint ignored: deserialize_content=False requires string content"
+            )
+        elif (
+            self.cache_breakpoint and "content" in data and data["content"] is not None
+        ):
+            cache_control = {"type": "ephemeral"}
+            if isinstance(data["content"], list):
+                # Multimodal: add cache_control to last block
+                if data["content"]:
+                    data["content"][-1]["cache_control"] = cache_control
+            else:
+                # String content: convert to content block format with cache_control
+                data["content"] = [
+                    {
+                        "type": "text",
+                        "text": data["content"],
+                        "cache_control": cache_control,
+                    }
+                ]
+
         if (info.context or {}).get("include_info"):
             data["info"] = self.info
         return data
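
The serializer behavior added above boils down to the following usage sketch (illustration only, not part of the commit; it assumes Message is importable from aviary.core, as the tests below do):

from aviary.core import Message

# String content: the breakpoint converts it to a single text block
# carrying cache_control.
msg = Message(content="hello", cache_breakpoint=True)
print(msg.model_dump(exclude_none=True))
# -> {'role': 'user', 'content': [{'type': 'text', 'text': 'hello',
#                                  'cache_control': {'type': 'ephemeral'}}]}

# Multimodal content: cache_control is attached to the last block only.
multi = Message(
    content=[{"type": "text", "text": "first"}, {"type": "text", "text": "second"}],
    cache_breakpoint=True,
)
print(multi.model_dump(exclude_none=True)["content"][-1])
# -> {'type': 'text', 'text': 'second', 'cache_control': {'type': 'ephemeral'}}

# With deserialize_content=False the flag is skipped (a warning is logged),
# so content stays a plain string.
print(msg.model_dump(context={"deserialize_content": False})["content"])  # 'hello'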

tests/test_messages.py

Lines changed: 143 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 import numpy as np
 import pytest
+from lmi import LiteLLMModel
 
 from aviary.core import (
     Message,
@@ -372,3 +373,145 @@ def test_prepend_text(self, subtests) -> None:
         assert trm.content is not None
         content_list_original = json.loads(trm.content)
         assert len(content_list_original) == 3
+
+
+class TestCacheBreakpoint:
+    def test_default_is_false(self) -> None:
+        msg = Message(content="test")
+        assert not msg.cache_breakpoint
+
+    def test_serialization_without_cache_breakpoint(self) -> None:
+        data = Message(content="test").model_dump(exclude_none=True)
+        assert data == {"role": "user", "content": "test"}
+
+    @pytest.mark.parametrize(
+        ("content", "expected_content"),
+        [
+            (
+                "test",
+                [
+                    {
+                        "type": "text",
+                        "text": "test",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            ),
+            (
+                [{"type": "text", "text": "first"}, {"type": "text", "text": "second"}],
+                [
+                    {"type": "text", "text": "first"},
+                    {
+                        "type": "text",
+                        "text": "second",
+                        "cache_control": {"type": "ephemeral"},
+                    },
+                ],
+            ),
+        ],
+    )
+    def test_serialization_with_cache_breakpoint(
+        self, content, expected_content
+    ) -> None:
+        data = Message(content=content, cache_breakpoint=True).model_dump(
+            exclude_none=True
+        )
+        assert data == {"role": "user", "content": expected_content}
+
+    def test_serialization_with_cache_breakpoint_empty_content(self) -> None:
+        data = Message(content=None, cache_breakpoint=True).model_dump(
+            exclude_none=True
+        )
+        # Should not crash, content stays None
+        assert data == {"role": "user"}
+
+    def test_cache_breakpoint_excluded_from_dump(self) -> None:
+        data = Message(content="test", cache_breakpoint=True).model_dump()
+        assert "cache_breakpoint" not in data
+
+    def test_cache_breakpoint_with_image_content(self) -> None:
+        data = Message.create_message(
+            text="Describe this image",
+            images=[
+                "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+            ],
+            cache_breakpoint=True,
+        ).model_dump(exclude_none=True)
+        # cache_control should be on the last block (the text block)
+        assert len(data["content"]) == 2
+        assert data["content"][0]["type"] == "image_url"
+        assert "cache_control" not in data["content"][0]
+        assert data["content"][1]["type"] == "text"
+        assert data["content"][1]["cache_control"] == {"type": "ephemeral"}
+
+    def test_cache_breakpoint_skipped_when_deserialize_content_false(self) -> None:
+        data = Message(content="test", cache_breakpoint=True).model_dump(
+            context={"deserialize_content": False}
+        )
+        # Content should remain a string, cache_breakpoint not applied
+        assert data["content"] == "test"
+
+    def test_cache_breakpoint_logs_warning_when_skipped(self, caplog) -> None:
+        import logging
+
+        msg = Message(content="test", cache_breakpoint=True)
+        with caplog.at_level(logging.WARNING):
+            msg.model_dump(context={"deserialize_content": False})
+        assert "cache_breakpoint ignored" in caplog.text
+
+
+def _make_long_content(prefix: str, num_items: int = 300) -> str:
+    """Generate long content for cache testing (>1024 tokens for Anthropic)."""
+    return prefix + " ".join(f"item_{i}" for i in range(num_items))
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("model_name", "require_cache_hit"),
+    [
+        ("claude-3-5-haiku-20241022", True),
+        ("gpt-4o-mini", False),
+    ],
+)
+async def test_cache_breakpoint_live(model_name: str, require_cache_hit: bool) -> None:
+    """Verify cache breakpoint behavior with different providers.
+
+    For Anthropic: cache_breakpoint causes upstream content to be cached.
+    For OpenAI: LiteLLM correctly strips cache_control, and OpenAI's automatic
+    prefix caching may or may not activate.
+    """
+    system_msg = Message(role="system", content=_make_long_content("System: "))
+    user_context = Message(role="user", content=_make_long_content("Context: "))
+    user_context.cache_breakpoint = True
+    assistant_msg = Message(role="assistant", content="Acknowledged.")
+    user_question = Message(role="user", content="Summarize.")
+
+    messages = [system_msg, user_context, assistant_msg, user_question]
+    llm = LiteLLMModel(name=model_name)
+
+    # First request - may create cache or hit existing cache
+    result1 = await llm.call_single(messages)
+    if require_cache_hit:
+        cache_active = (result1.cache_creation_tokens or 0) > 0 or (
+            result1.cache_read_tokens or 0
+        ) > 0
+        assert cache_active, "Expected cache creation or cache read on first request"
+    else:
+        assert result1.text is not None
+
+    # Second request - should hit cache (for Anthropic) or may hit (for OpenAI)
+    result2 = await llm.call_single(messages)
+    if require_cache_hit:
+        assert (result2.cache_read_tokens or 0) > 0, (
+            "Expected cache hit on second request"
+        )
+        assert (result2.cache_read_tokens or 0) > 500, (
+            f"Expected >500 cached tokens, got {result2.cache_read_tokens}"
+        )
+    else:
+        assert result2.text is not None
+        # OpenAI's caching is automatic and not guaranteed
+        if result2.cache_read_tokens is not None and result2.cache_read_tokens > 0:
+            assert result2.cache_read_tokens > 500, (
+                f"Expected >500 cached tokens if cache hit, got {result2.cache_read_tokens}"
+            )
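
Outside of pytest, the same pattern condenses to the sketch below (illustration only; it assumes the same lmi.LiteLLMModel interface and result attributes the live test relies on, plus provider credentials and content long enough to meet Anthropic's minimum cacheable size):

import asyncio

from lmi import LiteLLMModel

from aviary.core import Message


async def main() -> None:
    # Stable prefix: mark its last message so everything up to and including
    # it is eligible for Anthropic prompt caching.
    context = Message(
        role="user",
        content="Context: " + " ".join(f"item_{i}" for i in range(300)),
        cache_breakpoint=True,
    )
    question = Message(role="user", content="Summarize.")

    llm = LiteLLMModel(name="claude-3-5-haiku-20241022")
    first = await llm.call_single([context, question])
    second = await llm.call_single([context, question])
    # Mirrors the assertions above: the first call typically reports cache
    # creation (or a read of an existing cache), the second a cache read.
    print(first.cache_creation_tokens, second.cache_read_tokens)


asyncio.run(main())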

tests/test_tools.py

Lines changed: 5 additions & 3 deletions

@@ -762,6 +762,7 @@ def __init__(self):
 
 @pytest.mark.asyncio
 async def test_argref_by_name_async_functions() -> None:
+    # pylint: disable=unexpected-keyword-arg
     class MyState:
         def __init__(self):
             self.refs = {"foo": 1, "bar": 7}
@@ -809,6 +810,7 @@ async def async_list_direct(a: int, b: int) -> list[int]:  # noqa: RUF029
 
 @pytest.mark.asyncio
 async def test_argref_by_name_advanced_features() -> None:
+    # pylint: disable=unexpected-keyword-arg
     class MyState:
         def __init__(self):
             self.refs = {"foo": 1}
@@ -817,9 +819,9 @@ def __init__(self):
 
     # Define and test dereference via no state value found with return_direct
    @argref_by_name(return_direct=True)
-    def skip_deref_test(foo: float, a: str) -> str:
+    def skip_deref_test(ref_key: float, a: str) -> str:
         """Some docstring."""
-        return f"{foo} {a}"
+        return f"{ref_key} {a}"
 
     assert skip_deref_test("foo", "not in state", state=s) == "1 not in state"
     assert skip_deref_test("foo", "foo", state=s) == "1 1"
@@ -940,7 +942,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
         async def step(self, action):
             return await self.exec_tool_calls(action), False, 0, 0
 
-        async def export_frame(self):
+        def export_frame(self):
             pass
 
     with suppress_type_checks():

tests/test_utils.py

Lines changed: 7 additions & 8 deletions

@@ -155,14 +155,13 @@ class SomeModel(BaseModel):
     assert isinstance(gauss_next, float | None)
 
     # 2. Check deserialized RNG behaves as original RNG
-    for i, deserialized_model in enumerate((
-        SomeModel.model_validate_json(model.model_dump_json()),  # JSON str
-        SomeModel.model_validate(model.model_dump(mode="json")),  # JSON dict
-    )):
-        if i == 0:
-            # Sample original model once so RNG aligns for both deserialized
-            # models in the `for` loop
-            sampled_original = model.rng.sample(list(range(10)), k=6)
+    json_str = model.model_dump_json()
+    json_dict = model.model_dump(mode="json")
+    sampled_original = model.rng.sample(list(range(10)), k=6)
+    for deserialized_model in (
+        SomeModel.model_validate_json(json_str),
+        SomeModel.model_validate(json_dict),
+    ):
         assert isinstance(deserialized_model.rng, random.Random)
         sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6)
         assert sampled_original == sampled_deserialized, (

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.
