 
 import numpy as np
 import pytest
+from lmi import LiteLLMModel
 
 from aviary.core import (
     Message,
@@ -372,3 +373,145 @@ def test_prepend_text(self, subtests) -> None:
         assert trm.content is not None
         content_list_original = json.loads(trm.content)
         assert len(content_list_original) == 3
+
+
+class TestCacheBreakpoint:
+    def test_default_is_false(self) -> None:
+        msg = Message(content="test")
+        assert not msg.cache_breakpoint
+
+    def test_serialization_without_cache_breakpoint(self) -> None:
+        data = Message(content="test").model_dump(exclude_none=True)
+        assert data == {"role": "user", "content": "test"}
+
+    @pytest.mark.parametrize(
+        ("content", "expected_content"),
+        [
+            (
+                "test",
+                [
+                    {
+                        "type": "text",
+                        "text": "test",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            ),
+            (
+                [{"type": "text", "text": "first"}, {"type": "text", "text": "second"}],
+                [
+                    {"type": "text", "text": "first"},
+                    {
+                        "type": "text",
+                        "text": "second",
+                        "cache_control": {"type": "ephemeral"},
+                    },
+                ],
+            ),
+        ],
+    )
+    def test_serialization_with_cache_breakpoint(
+        self, content, expected_content
+    ) -> None:
+        data = Message(content=content, cache_breakpoint=True).model_dump(
+            exclude_none=True
+        )
+        assert data == {"role": "user", "content": expected_content}
+
+    def test_serialization_with_cache_breakpoint_empty_content(self) -> None:
+        data = Message(content=None, cache_breakpoint=True).model_dump(
+            exclude_none=True
+        )
+        # Should not crash, content stays None
+        assert data == {"role": "user"}
+
+    def test_cache_breakpoint_excluded_from_dump(self) -> None:
+        data = Message(content="test", cache_breakpoint=True).model_dump()
+        assert "cache_breakpoint" not in data
+
+    def test_cache_breakpoint_with_image_content(self) -> None:
+        data = Message.create_message(
+            text="Describe this image",
+            images=[
+                "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
+            ],
+            cache_breakpoint=True,
+        ).model_dump(exclude_none=True)
+        # cache_control should be on the last block (the text block)
+        assert len(data["content"]) == 2
+        assert data["content"][0]["type"] == "image_url"
+        assert "cache_control" not in data["content"][0]
+        assert data["content"][1]["type"] == "text"
+        assert data["content"][1]["cache_control"] == {"type": "ephemeral"}
+
+    def test_cache_breakpoint_skipped_when_deserialize_content_false(self) -> None:
+        data = Message(content="test", cache_breakpoint=True).model_dump(
+            context={"deserialize_content": False}
+        )
+        # Content should remain a string, cache_breakpoint not applied
+        assert data["content"] == "test"
+
+    def test_cache_breakpoint_logs_warning_when_skipped(self, caplog) -> None:
+        import logging
+
+        msg = Message(content="test", cache_breakpoint=True)
+        with caplog.at_level(logging.WARNING):
+            msg.model_dump(context={"deserialize_content": False})
+        assert "cache_breakpoint ignored" in caplog.text
+
+
+def _make_long_content(prefix: str, num_items: int = 300) -> str:
+    """Generate long content for cache testing.
+
+    Anthropic only caches prompt prefixes above a minimum length (1024 tokens for
+    most models, 2048 for Haiku models); 300 short items is roughly 1k tokens.
+    """
+    return prefix + " ".join(f"item_{i}" for i in range(num_items))
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("model_name", "require_cache_hit"),
+    [
+        ("claude-3-5-haiku-20241022", True),
+        ("gpt-4o-mini", False),
+    ],
+)
+async def test_cache_breakpoint_live(model_name: str, require_cache_hit: bool) -> None:
+    """Verify cache breakpoint behavior against live providers.
+
+    For Anthropic, cache_breakpoint causes the prompt prefix up to and including
+    the marked message to be cached. For OpenAI, LiteLLM strips the cache_control
+    blocks it cannot forward, and OpenAI's automatic prefix caching may or may not
+    activate.
+    """
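+    # With cache_breakpoint=True, the marked message serializes its last content
+    # block as {"type": "text", "text": ..., "cache_control": {"type": "ephemeral"}}
+    # (see TestCacheBreakpoint above), and Anthropic caches the prefix up to and
+    # including that block.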
+    system_msg = Message(role="system", content=_make_long_content("System: "))
+    user_context = Message(
+        role="user", content=_make_long_content("Context: "), cache_breakpoint=True
+    )
+    assistant_msg = Message(role="assistant", content="Acknowledged.")
+    user_question = Message(role="user", content="Summarize.")
+
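+    # The breakpoint on user_context marks system_msg + user_context as the
+    # cacheable prefix; assistant_msg and user_question are not part of it.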
+    messages = [system_msg, user_context, assistant_msg, user_question]
+    llm = LiteLLMModel(name=model_name)
+
+    # First request - may create cache or hit existing cache
+    result1 = await llm.call_single(messages)
+    if require_cache_hit:
+        cache_active = (result1.cache_creation_tokens or 0) > 0 or (
+            result1.cache_read_tokens or 0
+        ) > 0
+        assert cache_active, "Expected cache creation or cache read on first request"
+    else:
+        assert result1.text is not None
+
+    # Second request - should hit cache (for Anthropic) or may hit (for OpenAI)
+    result2 = await llm.call_single(messages)
+    if require_cache_hit:
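+        # The cached prefix (two ~300-item messages) runs to well over 500 tokens,
+        # so a genuine hit should report more than 500 cache_read_tokens.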
+        assert (result2.cache_read_tokens or 0) > 0, (
+            "Expected cache hit on second request"
+        )
+        assert (result2.cache_read_tokens or 0) > 500, (
+            f"Expected >500 cached tokens, got {result2.cache_read_tokens}"
+        )
+    else:
+        assert result2.text is not None
+        # OpenAI's caching is automatic and not guaranteed
+        if result2.cache_read_tokens is not None and result2.cache_read_tokens > 0:
+            assert result2.cache_read_tokens > 500, (
+                f"Expected >500 cached tokens if cache hit, got {result2.cache_read_tokens}"
+            )