Skip to content

Commit e1eea0d

Browse files
markpollack authored and SenreySong committed
Add Anthropic prompt caching via AnthropicChatOptions
This commit implements comprehensive prompt caching support for Anthropic Claude models in Spring AI:

Core Implementation:
- Add AnthropicCacheStrategy enum with 4 strategic options: NONE, SYSTEM_ONLY, SYSTEM_AND_TOOLS, CONVERSATION_HISTORY
- Implement strategic cache placement with automatic 4-breakpoint limit enforcement via CacheBreakpointTracker
- Support configurable TTL durations: "5m" (default) and "1h" (requires beta header)
- Add cache_control support to system messages, tools, and conversation history based on strategy

API Changes:
- Extend AnthropicChatOptions with cacheStrategy() and cacheTtl() builder methods
- Update AnthropicApi.Tool record to support cache_control field
- Add cache usage tracking via cacheCreationInputTokens() and cacheReadInputTokens()

Testing & Quality:
- Add comprehensive integration tests with real-world scenarios
- Add extensive mock test coverage with complex multi-breakpoint scenarios
- Fix all checkstyle violations and test failures
- Add cache breakpoint limit warning for production debugging

Documentation:
- Complete API documentation with practical examples and best practices
- Add real-world use cases: legal document analysis, batch code review, customer support
- Include cost optimization guidance demonstrating up to 90% savings
- Document future enhancement roadmap for advanced scenarios

Signed-off-by: Mark Pollack <[email protected]>
Signed-off-by: Soby Chacko <[email protected]>
1 parent ba67460 commit e1eea0d

File tree

15 files changed

+2077
-237
lines changed

15 files changed

+2077
-237
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,7 @@ qodana.yaml
5151
__pycache__/
5252
*.pyc
5353
tmp
54+
55+
56+
plans
57+

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 246 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4343
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4444
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
45+
import org.springframework.ai.anthropic.api.AnthropicCacheStrategy;
4546
import org.springframework.ai.chat.messages.AssistantMessage;
47+
import org.springframework.ai.chat.messages.Message;
4648
import org.springframework.ai.chat.messages.MessageType;
4749
import org.springframework.ai.chat.messages.ToolResponseMessage;
4850
import org.springframework.ai.chat.messages.UserMessage;
@@ -460,6 +462,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
460462
this.defaultOptions.getToolCallbacks()));
461463
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
462464
this.defaultOptions.getToolContext()));
465+
466+
// Merge cache strategy and TTL (also @JsonIgnore fields)
467+
requestOptions.setCacheStrategy(runtimeOptions.getCacheStrategy() != null
468+
? runtimeOptions.getCacheStrategy() : this.defaultOptions.getCacheStrategy());
469+
requestOptions.setCacheTtl(runtimeOptions.getCacheTtl() != null ? runtimeOptions.getCacheTtl()
470+
: this.defaultOptions.getCacheTtl());
463471
}
464472
else {
465473
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
@@ -483,81 +491,75 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
483491

484492
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
485493

486-
// Get cache control from options
487-
AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions();
488-
AnthropicApi.ChatCompletionRequest.CacheControl cacheControl = (requestOptions != null)
489-
? requestOptions.getCacheControl() : null;
494+
// Get caching strategy and options from the request
495+
logger.debug("DEBUGINFO: prompt.getOptions() type: {}, value: {}",
496+
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null", prompt.getOptions());
490497

491-
List<AnthropicMessage> userMessages = prompt.getInstructions()
492-
.stream()
493-
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
494-
.map(message -> {
495-
if (message.getMessageType() == MessageType.USER) {
496-
List<ContentBlock> contents = new ArrayList<>();
498+
AnthropicChatOptions requestOptions = null;
499+
if (prompt.getOptions() instanceof AnthropicChatOptions) {
500+
requestOptions = (AnthropicChatOptions) prompt.getOptions();
501+
logger.debug("DEBUGINFO: Found AnthropicChatOptions - cacheStrategy: {}, cacheTtl: {}",
502+
requestOptions.getCacheStrategy(), requestOptions.getCacheTtl());
503+
}
504+
else {
505+
logger.debug("DEBUGINFO: Options is NOT AnthropicChatOptions, it's: {}",
506+
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null");
507+
}
497508

498-
// Apply cache control if enabled for user messages
499-
if (cacheControl != null) {
500-
contents.add(new ContentBlock(message.getText(), cacheControl));
501-
}
502-
else {
503-
contents.add(new ContentBlock(message.getText()));
504-
}
505-
if (message instanceof UserMessage userMessage) {
506-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
507-
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
508-
Type contentBlockType = getContentBlockTypeByMedia(media);
509-
var source = getSourceByMedia(media);
510-
return new ContentBlock(contentBlockType, source);
511-
}).toList();
512-
contents.addAll(mediaContent);
513-
}
514-
}
515-
return new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name()));
516-
}
517-
else if (message.getMessageType() == MessageType.ASSISTANT) {
518-
AssistantMessage assistantMessage = (AssistantMessage) message;
519-
List<ContentBlock> contentBlocks = new ArrayList<>();
520-
if (StringUtils.hasText(message.getText())) {
521-
contentBlocks.add(new ContentBlock(message.getText()));
522-
}
523-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
524-
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
525-
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
526-
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
527-
}
528-
}
529-
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
530-
}
531-
else if (message.getMessageType() == MessageType.TOOL) {
532-
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
533-
.stream()
534-
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
535-
toolResponse.responseData()))
536-
.toList();
537-
return new AnthropicMessage(toolResponses, Role.USER);
538-
}
539-
else {
540-
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
541-
}
542-
})
543-
.toList();
509+
AnthropicCacheStrategy strategy = requestOptions != null ? requestOptions.getCacheStrategy()
510+
: AnthropicCacheStrategy.NONE;
511+
String cacheTtl = requestOptions != null ? requestOptions.getCacheTtl() : "5m";
544512

545-
String systemPrompt = prompt.getInstructions()
546-
.stream()
547-
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
548-
.map(m -> m.getText())
549-
.collect(Collectors.joining(System.lineSeparator()));
513+
logger.debug("Cache strategy: {}, TTL: {}", strategy, cacheTtl);
514+
515+
// Track how many breakpoints we've used (max 4)
516+
CacheBreakpointTracker breakpointsUsed = new CacheBreakpointTracker();
517+
ChatCompletionRequest.CacheControl cacheControl = null;
518+
519+
if (strategy != AnthropicCacheStrategy.NONE) {
520+
// Create cache control with TTL if specified, otherwise use default 5m
521+
if (cacheTtl != null && !cacheTtl.equals("5m")) {
522+
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral", cacheTtl);
523+
logger.debug("Created cache control with TTL: type={}, ttl={}", "ephemeral", cacheTtl);
524+
}
525+
else {
526+
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral");
527+
logger.debug("Created cache control with default TTL: type={}, ttl={}", "ephemeral", "5m");
528+
}
529+
}
530+
531+
// Build messages WITHOUT blanket cache control - strategic placement only
532+
List<AnthropicMessage> userMessages = buildMessages(prompt, strategy, cacheControl, breakpointsUsed);
550533

534+
// Process system - as array if caching, string otherwise
535+
Object systemContent = buildSystemContent(prompt, strategy, cacheControl, breakpointsUsed);
536+
537+
// Build base request
551538
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
552-
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
539+
systemContent, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
553540

554541
request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);
555542

556-
// Add the tool definitions to the request's tools parameter.
543+
// Add the tool definitions with potential caching
557544
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
558545
if (!CollectionUtils.isEmpty(toolDefinitions)) {
559546
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
560-
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
547+
List<AnthropicApi.Tool> tools = getFunctionTools(toolDefinitions);
548+
549+
// Apply caching to tools if strategy includes them
550+
if ((strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
551+
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()) {
552+
tools = addCacheToLastTool(tools, cacheControl, breakpointsUsed);
553+
}
554+
555+
request = ChatCompletionRequest.from(request).tools(tools).build();
556+
}
557+
558+
// Add beta header for 1-hour TTL if needed
559+
if ("1h".equals(cacheTtl) && requestOptions != null) {
560+
Map<String, String> headers = new HashMap<>(requestOptions.getHttpHeaders());
561+
headers.put("anthropic-beta", AnthropicApi.BETA_EXTENDED_CACHE_TTL);
562+
requestOptions.setHttpHeaders(headers);
561563
}
562564

563565
return request;
@@ -573,6 +575,154 @@ private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefini
573575
}).toList();
574576
}
575577

578+
/**
579+
* Build messages strategically, applying cache control only where specified by the
580+
* strategy.
581+
*/
582+
private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrategy strategy,
583+
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
584+
585+
List<Message> allMessages = prompt.getInstructions()
586+
.stream()
587+
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
588+
.toList();
589+
590+
// Find the last user message (current question) for CONVERSATION_HISTORY strategy
591+
int lastUserIndex = -1;
592+
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) {
593+
for (int i = allMessages.size() - 1; i >= 0; i--) {
594+
if (allMessages.get(i).getMessageType() == MessageType.USER) {
595+
lastUserIndex = i;
596+
break;
597+
}
598+
}
599+
}
600+
601+
List<AnthropicMessage> result = new ArrayList<>();
602+
for (int i = 0; i < allMessages.size(); i++) {
603+
Message message = allMessages.get(i);
604+
boolean shouldApplyCache = false;
605+
606+
// Apply cache to history tail (message before current question) for
607+
// CONVERSATION_HISTORY
608+
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY && breakpointsUsed.canUse()) {
609+
if (lastUserIndex > 0) {
610+
// Cache the message immediately before the last user message
611+
// (multi-turn conversation)
612+
shouldApplyCache = (i == lastUserIndex - 1);
613+
}
614+
if (shouldApplyCache) {
615+
breakpointsUsed.use();
616+
}
617+
}
618+
619+
if (message.getMessageType() == MessageType.USER) {
620+
List<ContentBlock> contents = new ArrayList<>();
621+
622+
// Apply cache control strategically, not to all user messages
623+
if (shouldApplyCache && cacheControl != null) {
624+
contents.add(new ContentBlock(message.getText(), cacheControl));
625+
}
626+
else {
627+
contents.add(new ContentBlock(message.getText()));
628+
}
629+
630+
if (message instanceof UserMessage userMessage) {
631+
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
632+
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
633+
Type contentBlockType = getContentBlockTypeByMedia(media);
634+
var source = getSourceByMedia(media);
635+
return new ContentBlock(contentBlockType, source);
636+
}).toList();
637+
contents.addAll(mediaContent);
638+
}
639+
}
640+
result.add(new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name())));
641+
}
642+
else if (message.getMessageType() == MessageType.ASSISTANT) {
643+
AssistantMessage assistantMessage = (AssistantMessage) message;
644+
List<ContentBlock> contentBlocks = new ArrayList<>();
645+
if (StringUtils.hasText(message.getText())) {
646+
contentBlocks.add(new ContentBlock(message.getText()));
647+
}
648+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
649+
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
650+
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
651+
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
652+
}
653+
}
654+
result.add(new AnthropicMessage(contentBlocks, Role.ASSISTANT));
655+
}
656+
else if (message.getMessageType() == MessageType.TOOL) {
657+
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
658+
.stream()
659+
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
660+
toolResponse.responseData()))
661+
.toList();
662+
result.add(new AnthropicMessage(toolResponses, Role.USER));
663+
}
664+
else {
665+
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
666+
}
667+
}
668+
return result;
669+
}
670+
671+
/**
672+
* Build system content - as array if caching, string otherwise.
673+
*/
674+
private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy,
675+
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
676+
677+
String systemText = prompt.getInstructions()
678+
.stream()
679+
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
680+
.map(Message::getText)
681+
.collect(Collectors.joining(System.lineSeparator()));
682+
683+
if (!StringUtils.hasText(systemText)) {
684+
return null;
685+
}
686+
687+
// Use array format when caching system
688+
if ((strategy == AnthropicCacheStrategy.SYSTEM_ONLY || strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
689+
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()
690+
&& cacheControl != null) {
691+
692+
logger.debug("Applying cache control to system message - strategy: {}, cacheControl: {}", strategy,
693+
cacheControl);
694+
List<ContentBlock> systemBlocks = List.of(new ContentBlock(systemText, cacheControl));
695+
breakpointsUsed.use();
696+
return systemBlocks;
697+
}
698+
699+
// Use string format when not caching (backward compatible)
700+
return systemText;
701+
}
702+
703+
/**
704+
* Add cache control to the last tool for deterministic caching.
705+
*/
706+
private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools,
707+
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {
708+
709+
if (tools == null || tools.isEmpty() || !breakpointsUsed.canUse() || cacheControl == null) {
710+
return tools;
711+
}
712+
713+
List<AnthropicApi.Tool> modifiedTools = new ArrayList<>();
714+
for (int i = 0; i < tools.size(); i++) {
715+
AnthropicApi.Tool tool = tools.get(i);
716+
if (i == tools.size() - 1) {
717+
// Add cache control to last tool
718+
tool = new AnthropicApi.Tool(tool.name(), tool.description(), tool.inputSchema(), cacheControl);
719+
breakpointsUsed.use();
720+
}
721+
modifiedTools.add(tool);
722+
}
723+
return modifiedTools;
724+
}
725+
576726
@Override
577727
public ChatOptions getDefaultOptions() {
578728
return AnthropicChatOptions.fromOptions(this.defaultOptions);
@@ -654,4 +804,36 @@ public AnthropicChatModel build() {
654804

655805
}
656806

807+
/**
808+
* Tracks cache breakpoints used (max 4 allowed by Anthropic). Non-static to ensure
809+
* each request has its own instance.
810+
*/
811+
private class CacheBreakpointTracker {
812+
813+
private int count = 0;
814+
815+
private boolean hasWarned = false;
816+
817+
public boolean canUse() {
818+
return this.count < 4;
819+
}
820+
821+
public void use() {
822+
if (this.count < 4) {
823+
this.count++;
824+
}
825+
else if (!this.hasWarned) {
826+
logger.warn(
827+
"Anthropic cache breakpoint limit (4) reached. Additional cache_control directives will be ignored. "
828+
+ "Consider using fewer cache strategies or simpler content structure.");
829+
this.hasWarned = true;
830+
}
831+
}
832+
833+
public int getCount() {
834+
return this.count;
835+
}
836+
837+
}
838+
657839
}

0 commit comments

Comments
 (0)