Skip to content

Commit 459afad

Browse files
Added ability to add custom cache breakpoints for Anthropic models
1 parent cd40737 commit 459afad

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class StructuredDiscussion:
6565

6666
def __init__(self, keep_last_n_obs=None):
6767
self.groups: list[MsgGroup] = []
68-
self.keep_last_n_obs = keep_last_n_obs
68+
self.keep_last_n_obs: int| None = keep_last_n_obs
6969

7070
def append(self, message: MessageBuilder):
7171
"""Append a message to the last group."""
@@ -87,11 +87,14 @@ def flatten(self) -> list[MessageBuilder]:
8787
print(
8888
f"Processing group {i} ({group.name}), is_tail={is_tail}, len(greoup)={len(group.messages)}"
8989
)
90+
# Include only summary if group not in last n groups.
9091
if not is_tail and group.summary is not None:
9192
messages.append(group.summary)
9293
else:
9394
messages.extend(group.messages)
94-
95+
# Mark all summarized messages for caching
96+
if i == len(self.groups) - keep_last_n_obs:
97+
messages[i].mark_all_previous_msg_for_caching()
9598
return messages
9699

97100
def set_last_summary(self, summary: MessageBuilder):
@@ -453,7 +456,8 @@ def get_action(self, obs: Any) -> float:
453456
messages=messages,
454457
tool_choice="any",
455458
cache_tool_definition=True,
456-
cache_complete_prompt=True,
459+
cache_complete_prompt=False,
460+
use_cache_breakpoints=True,
457461
)
458462

459463
action = response.action

src/agentlab/llm/response_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def add_image_url(self, image_url: str) -> "MessageBuilder":
107107
self.content.append({"image": image_to_png_base64_url(image_url)})
108108
return self
109109

110+
def mark_all_previous_msg_for_caching(self):
111+
"""Insert a cache breakpoint in the message content."""
112+
# This is a placeholder for future implementation.
113+
raise NotImplementedError
114+
110115

111116
# TODO: Support parallel tool calls.
112117

@@ -216,6 +221,10 @@ def transform_content(self, content: ContentItem) -> ContentItem:
216221
else:
217222
raise ValueError(f"Unsupported content type: {content}")
218223

224+
def mark_all_previous_msg_for_caching(self) -> List[Message]:
225+
"""Insert a cache breakpoint in the message content to mark all previous messages for caching."""
226+
self._cache_breakpoint = True
227+
219228

220229
class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder):
221230

@@ -521,7 +530,10 @@ def _call_api(
521530
sys_msg, other_msgs = self.filter_system_messages(messages)
522531
sys_msg_text = "\n".join(c["text"] for m in sys_msg for c in m.content)
523532
for msg in other_msgs:
524-
input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
533+
temp = msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]
534+
if kwargs.pop("use_cache_breakpoints", False):
535+
temp = self.apply_cache_breakpoints(msg, temp)
536+
input.extend(temp)
525537

526538
api_params: Dict[str, Any] = {
527539
"model": self.model_name,
@@ -581,6 +593,16 @@ def _parse_response(self, response: dict) -> dict:
581593
result.think += output.text
582594
return result
583595

596+
# def ensure_cache_conditions(self, msgs: List[Message]) -> bool:
597+
# """Ensure API specific cache conditions are met."""
598+
# assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message."
599+
600+
def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Message]:
601+
"""Apply cache breakpoints to the messages."""
602+
if getattr(msg, "_cache_breakpoint", False):
603+
prepared_msg[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
604+
return prepared_msg
605+
584606

585607
def cua_response_to_text(action):
586608
"""

Comments (0)