Skip to content
Merged
51 changes: 46 additions & 5 deletions src/aviary/message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING, ClassVar, Self

Expand All @@ -18,6 +19,8 @@

import numpy as np

logger = logging.getLogger(__name__)


class Message(BaseModel):
DEFAULT_ROLE: ClassVar[str] = "user"
Expand Down Expand Up @@ -60,6 +63,14 @@ class Message(BaseModel):
repr=False,
)

cache_breakpoint: bool = Field(
default=False,
description="Mark this message as a cache breakpoint for prompt caching. "
"When True, adds cache_control to the content during serialization.",
exclude=True,
repr=False,
)

@field_validator("role")
@classmethod
def check_role(cls, v: str) -> str:
Expand Down Expand Up @@ -97,14 +108,44 @@ def _serialize(self, handler, info: SerializationInfo):
as LLM APIs expect multimodal content as structured blocks.
- Other structured content stays as a JSON string,
as tool response content must be a string for LLM API compatibility.

For cache_breakpoint:
- When True, adds cache_control to the content for prompt caching.
- String content is converted to content block format, the list-of-dicts
representation that LLM APIs use for structured content, e.g.
[{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral"}}].
- Multimodal content has cache_control added to the last block.
"""
data = handler(self)
if (
self.is_multimodal
and "content" in data
and (info.context or {}).get("deserialize_content", True)
):
deserialize_content = (info.context or {}).get("deserialize_content", True)
if self.is_multimodal and "content" in data and deserialize_content:
data["content"] = json.loads(data["content"])

# Handle cache_breakpoint - add cache_control to content
# Skip when deserialize_content=False as it would convert string to list,
# breaking call sites that require string content (e.g., ToolResponseMessage)
if self.cache_breakpoint and not deserialize_content:
logger.warning(
"cache_breakpoint ignored: deserialize_content=False requires string content"
)
elif (
self.cache_breakpoint and "content" in data and data["content"] is not None
):
cache_control = {"type": "ephemeral"}
if isinstance(data["content"], list):
# Multimodal: add cache_control to last block
if data["content"]:
data["content"][-1]["cache_control"] = cache_control
else:
# String content: convert to content block format with cache_control
data["content"] = [
{
"type": "text",
"text": data["content"],
"cache_control": cache_control,
}
]

if (info.context or {}).get("include_info"):
data["info"] = self.info
return data
Expand Down
170 changes: 170 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,173 @@ def test_prepend_text(self, subtests) -> None:
assert trm.content is not None
content_list_original = json.loads(trm.content)
assert len(content_list_original) == 3


class TestCacheBreakpoint:
def test_default_is_false(self) -> None:
msg = Message(content="test")
assert not msg.cache_breakpoint

def test_serialization_without_cache_breakpoint(self) -> None:
data = Message(content="test").model_dump(exclude_none=True)
assert data == {"role": "user", "content": "test"}

def test_serialization_with_cache_breakpoint_string_content(self) -> None:
data = Message(content="test", cache_breakpoint=True).model_dump(
exclude_none=True
)
assert data == {
"role": "user",
"content": [
{"type": "text", "text": "test", "cache_control": {"type": "ephemeral"}}
],
}

def test_serialization_with_cache_breakpoint_multimodal_content(self) -> None:
data = Message(
content=[
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
],
cache_breakpoint=True,
).model_dump(exclude_none=True)
# cache_control should be on the last block
assert data["content"][0] == {"type": "text", "text": "first"}
assert data["content"][1] == {
"type": "text",
"text": "second",
"cache_control": {"type": "ephemeral"},
}

def test_serialization_with_cache_breakpoint_empty_content(self) -> None:
data = Message(content=None, cache_breakpoint=True).model_dump(
exclude_none=True
)
# Should not crash, content stays None
assert data == {"role": "user"}

def test_cache_breakpoint_excluded_from_dump(self) -> None:
data = Message(content="test", cache_breakpoint=True).model_dump()
assert "cache_breakpoint" not in data

def test_cache_breakpoint_with_image_content(self) -> None:
data = Message.create_message(
text="Describe this image",
images=[
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
],
cache_breakpoint=True,
).model_dump(exclude_none=True)
# cache_control should be on the last block (the text block)
assert len(data["content"]) == 2
assert data["content"][0]["type"] == "image_url"
assert "cache_control" not in data["content"][0]
assert data["content"][1]["type"] == "text"
assert data["content"][1]["cache_control"] == {"type": "ephemeral"}

def test_cache_breakpoint_skipped_when_deserialize_content_false(self) -> None:
data = Message(content="test", cache_breakpoint=True).model_dump(
context={"deserialize_content": False}
)
# Content should remain a string, cache_breakpoint not applied
assert data["content"] == "test"

def test_cache_breakpoint_logs_warning_when_skipped(self, caplog) -> None:
import logging

msg = Message(content="test", cache_breakpoint=True)
with caplog.at_level(logging.WARNING):
msg.model_dump(context={"deserialize_content": False})
assert "cache_breakpoint ignored" in caplog.text


def _make_long_content(prefix: str, num_items: int = 300) -> str:
"""Generate long content for cache testing (>1024 tokens for Anthropic)."""
return prefix + " ".join(f"item_{i}" for i in range(num_items))


@pytest.mark.asyncio
async def test_cache_breakpoint_live() -> None:
"""Verify cache breakpoint causes upstream content to be cached.

When cache_breakpoint is set on a user message, all content up to and
including that message should be cached, even content in prior messages
that don't have cache_breakpoint set.
"""
from lmi import LiteLLMModel

# System message - NOT marked for caching, but will be cached
# because it's upstream of the breakpoint
system_msg = Message(role="system", content=_make_long_content("System: "))

# User context message - marked for caching
# This caches everything up to and including this message
user_context = Message(role="user", content=_make_long_content("Context: "))
user_context.cache_breakpoint = True

# Simulated assistant acknowledgment
assistant_msg = Message(role="assistant", content="Acknowledged.")

# New user question (not cached)
user_question = Message(role="user", content="Summarize.")

messages = [system_msg, user_context, assistant_msg, user_question]

llm = LiteLLMModel(name="claude-3-5-haiku-20241022")

# First request - may create cache or hit existing cache
result1 = await llm.call_single(messages)
cache_active = (result1.cache_creation_tokens or 0) > 0 or (
result1.cache_read_tokens or 0
) > 0
assert cache_active, "Expected cache creation or cache read on first request"

# Second request - should hit cache
result2 = await llm.call_single(messages)
assert (result2.cache_read_tokens or 0) > 0, "Expected cache hit on second request"
# Cached content includes both system and user context (~600 items = ~1200+ tokens)
assert (result2.cache_read_tokens or 0) > 500, (
f"Expected >500 cached tokens, got {result2.cache_read_tokens}"
)


@pytest.mark.asyncio
async def test_cache_breakpoint_openai_live() -> None:
"""Verify cache_breakpoint doesn't interfere with OpenAI's automatic caching.

OpenAI uses automatic prefix-based caching (no explicit breakpoints).
LiteLLM strips cache_control from content blocks before sending to OpenAI.
This test verifies that setting cache_breakpoint doesn't break anything
and OpenAI's native caching still works.
"""
from lmi import LiteLLMModel

# Long content to exceed OpenAI's 1024 token caching threshold
system_msg = Message(role="system", content=_make_long_content("System: "))

# Setting cache_breakpoint - LiteLLM will strip cache_control for OpenAI
user_context = Message(role="user", content=_make_long_content("Context: "))
user_context.cache_breakpoint = True

assistant_msg = Message(role="assistant", content="Acknowledged.")
user_question = Message(role="user", content="Summarize.")

messages = [system_msg, user_context, assistant_msg, user_question]

llm = LiteLLMModel(name="gpt-4o-mini")

# First request
result1 = await llm.call_single(messages)
assert result1.text is not None

# Second request - OpenAI's automatic caching may hit
result2 = await llm.call_single(messages)
assert result2.text is not None

# OpenAI's caching is automatic and not guaranteed, but if it works,
# cache_read_tokens should be populated. At minimum, verify no errors
# occurred from the cache_breakpoint serialization.
if result2.cache_read_tokens is not None and result2.cache_read_tokens > 0:
assert result2.cache_read_tokens > 500, (
f"Expected >500 cached tokens if cache hit, got {result2.cache_read_tokens}"
)
8 changes: 5 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ def __init__(self):

@pytest.mark.asyncio
async def test_argref_by_name_async_functions() -> None:
# pylint: disable=unexpected-keyword-arg
class MyState:
def __init__(self):
self.refs = {"foo": 1, "bar": 7}
Expand Down Expand Up @@ -809,6 +810,7 @@ async def async_list_direct(a: int, b: int) -> list[int]: # noqa: RUF029

@pytest.mark.asyncio
async def test_argref_by_name_advanced_features() -> None:
# pylint: disable=unexpected-keyword-arg
class MyState:
def __init__(self):
self.refs = {"foo": 1}
Expand All @@ -817,9 +819,9 @@ def __init__(self):

# Define and test dereference via no state value found with return_direct
@argref_by_name(return_direct=True)
def skip_deref_test(foo: float, a: str) -> str:
def skip_deref_test(ref_key: float, a: str) -> str:
"""Some docstring."""
return f"{foo} {a}"
return f"{ref_key} {a}"

assert skip_deref_test("foo", "not in state", state=s) == "1 not in state"
assert skip_deref_test("foo", "foo", state=s) == "1 1"
Expand Down Expand Up @@ -940,7 +942,7 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
async def step(self, action):
return await self.exec_tool_calls(action), False, 0, 0

async def export_frame(self):
def export_frame(self):
pass

with suppress_type_checks():
Expand Down
15 changes: 7 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,13 @@ class SomeModel(BaseModel):
assert isinstance(gauss_next, float | None)

# 2. Check deserialized RNG behaves as original RNG
for i, deserialized_model in enumerate((
SomeModel.model_validate_json(model.model_dump_json()), # JSON str
SomeModel.model_validate(model.model_dump(mode="json")), # JSON dict
)):
if i == 0:
# Sample original model once so RNG aligns for both deserialized
# models in the `for` loop
sampled_original = model.rng.sample(list(range(10)), k=6)
json_str = model.model_dump_json()
json_dict = model.model_dump(mode="json")
sampled_original = model.rng.sample(list(range(10)), k=6)
for deserialized_model in (
SomeModel.model_validate_json(json_str),
SomeModel.model_validate(json_dict),
):
assert isinstance(deserialized_model.rng, random.Random)
sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6)
assert sampled_original == sampled_deserialized, (
Expand Down