@@ -107,6 +107,11 @@ def add_image_url(self, image_url: str) -> "MessageBuilder":
107107 self .content .append ({"image" : image_to_png_base64_url (image_url )})
108108 return self
109109
110+ def mark_all_previous_msg_for_caching (self ):
111+ """Insert a cache breakpoint in the message content."""
112+ # This is a placeholder for future implementation.
113+ raise NotImplementedError
114+
110115
111116# TODO: Support parallel tool calls.
112117
@@ -216,6 +221,10 @@ def transform_content(self, content: ContentItem) -> ContentItem:
216221 else :
217222 raise ValueError (f"Unsupported content type: { content } " )
218223
224+ def mark_all_previous_msg_for_caching (self ) -> List [Message ]:
225+ """Insert a cache breakpoint in the message content to mark all previous messages for caching."""
226+ self ._cache_breakpoint = True
227+
219228
220229class OpenAIChatCompletionAPIMessageBuilder (MessageBuilder ):
221230
@@ -521,7 +530,10 @@ def _call_api(
521530 sys_msg , other_msgs = self .filter_system_messages (messages )
522531 sys_msg_text = "\n " .join (c ["text" ] for m in sys_msg for c in m .content )
523532 for msg in other_msgs :
524- input .extend (msg .prepare_message () if isinstance (msg , MessageBuilder ) else [msg ])
533+ temp = msg .prepare_message () if isinstance (msg , MessageBuilder ) else [msg ]
534+ if kwargs .pop ("use_cache_breakpoints" , False ):
535+ temp = self .apply_cache_breakpoints (msg , temp )
536+ input .extend (temp )
525537
526538 api_params : Dict [str , Any ] = {
527539 "model" : self .model_name ,
@@ -581,6 +593,16 @@ def _parse_response(self, response: dict) -> dict:
581593 result .think += output .text
582594 return result
583595
596+ # def ensure_cache_conditions(self, msgs: List[Message]) -> bool:
597+ # """Ensure API specific cache conditions are met."""
598+ # assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message."
599+
600+ def apply_cache_breakpoints (self , msg : Message , prepared_msg : dict ) -> List [Message ]:
601+ """Apply cache breakpoints to the messages."""
602+ if getattr (msg , "_cache_breakpoint" , False ):
603+ prepared_msg [- 1 ]["content" ][- 1 ]["cache_control" ] = {"type" : "ephemeral" }
604+ return prepared_msg
605+
584606
585607def cua_response_to_text (action ):
586608 """
0 commit comments