Skip to content

Commit 15e5026

Browse files
committed
feat(anthropic): add support for prompt caching
Implements Anthropic's prompt caching feature to improve token efficiency. - Adds cache control support in AnthropicApi and AnthropicChatModel - Creates AnthropicCacheType enum with EPHEMERAL cache type - Extends AbstractMessage and UserMessage to support cache parameters - Updates Usage tracking to include cache-related token metrics - Adds integration test to verify prompt caching functionality This implementation follows Anthropic's prompt caching API (beta-2024-07-31) which allows for more efficient token usage by caching frequently used prompts. Signed-off-by: “claudio-code” <[email protected]>
1 parent c89bb4a commit 15e5026

File tree

7 files changed

+159
-28
lines changed

7 files changed

+159
-28
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4343
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4444
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
45+
import org.springframework.ai.anthropic.api.AnthropicCacheType;
46+
import org.springframework.ai.chat.messages.AbstractMessage;
4547
import org.springframework.ai.chat.messages.AssistantMessage;
4648
import org.springframework.ai.chat.messages.MessageType;
4749
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -432,7 +434,16 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
432434
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
433435
.map(message -> {
434436
if (message.getMessageType() == MessageType.USER) {
435-
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
437+
AbstractMessage abstractMessage = (AbstractMessage) message;
438+
List<ContentBlock> contents;
439+
if (abstractMessage.getCache() != null) {
440+
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
441+
contents = new ArrayList<>(
442+
List.of(new ContentBlock(message.getText(), cacheType.cacheControl())));
443+
}
444+
else {
445+
contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
446+
}
436447
if (message instanceof UserMessage userMessage) {
437448
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
438449
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import reactor.core.publisher.Flux;
3333
import reactor.core.publisher.Mono;
3434

35-
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
35+
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;
3636
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
3737
import org.springframework.ai.model.ChatModelDescription;
3838
import org.springframework.ai.model.ModelOptionsUtils;
@@ -77,6 +77,8 @@ public class AnthropicApi {
7777

7878
private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta";
7979

80+
public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31";
81+
8082
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
8183

8284
private final RestClient restClient;
@@ -495,25 +497,30 @@ public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Stri
495497
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null);
496498
}
497499

498-
public static ChatCompletionRequestBuilder builder() {
499-
return new ChatCompletionRequestBuilder();
500-
}
501-
502-
public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) {
503-
return new ChatCompletionRequestBuilder(request);
504-
}
505-
506500
/**
507-
* Metadata about the request.
508-
*
509501
* @param userId An external identifier for the user who is associated with the
510502
* request. This should be a uuid, hash value, or other opaque identifier.
511503
* Anthropic may use this id to help detect abuse. Do not include any identifying
512504
* information such as name, email address, or phone number.
513505
*/
514506
@JsonInclude(Include.NON_NULL)
515507
public record Metadata(@JsonProperty("user_id") String userId) {
508+
}
516509

510+
/**
511+
* @param type is the cache type supported by anthropic. <a href=
512+
* "https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations">Doc</a>
513+
*/
514+
@JsonInclude(Include.NON_NULL)
515+
public record CacheControl(String type) {
516+
}
517+
518+
public static ChatCompletionRequestBuilder builder() {
519+
return new ChatCompletionRequestBuilder();
520+
}
521+
522+
public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) {
523+
return new ChatCompletionRequestBuilder(request);
517524
}
518525

519526
}
@@ -689,7 +696,10 @@ public record ContentBlock(
689696

690697
// tool_result response only
691698
@JsonProperty("tool_use_id") String toolUseId,
692-
@JsonProperty("content") String content
699+
@JsonProperty("content") String content,
700+
701+
// cache object
702+
@JsonProperty("cache_control") CacheControl cacheControl
693703
) {
694704
// @formatter:on
695705

@@ -708,23 +718,27 @@ public ContentBlock(String mediaType, String data) {
708718
* @param source The source of the content.
709719
*/
710720
public ContentBlock(Type type, Source source) {
711-
this(type, source, null, null, null, null, null, null, null);
721+
this(type, source, null, null, null, null, null, null, null, null);
712722
}
713723

714724
/**
715725
* Create content block
716726
* @param source The source of the content.
717727
*/
718728
public ContentBlock(Source source) {
719-
this(Type.IMAGE, source, null, null, null, null, null, null, null);
729+
this(Type.IMAGE, source, null, null, null, null, null, null, null, null);
720730
}
721731

722732
/**
723733
* Create content block
724734
* @param text The text of the content.
725735
*/
726736
public ContentBlock(String text) {
727-
this(Type.TEXT, null, text, null, null, null, null, null, null);
737+
this(Type.TEXT, null, text, null, null, null, null, null, null, null);
738+
}
739+
740+
public ContentBlock(String text, CacheControl cache) {
741+
this(Type.TEXT, null, text, null, null, null, null, null, null, cache);
728742
}
729743

730744
// Tool result
@@ -735,7 +749,7 @@ public ContentBlock(String text) {
735749
* @param content The content of the tool result.
736750
*/
737751
public ContentBlock(Type type, String toolUseId, String content) {
738-
this(type, null, null, null, null, null, null, toolUseId, content);
752+
this(type, null, null, null, null, null, null, toolUseId, content, null);
739753
}
740754

741755
/**
@@ -746,7 +760,7 @@ public ContentBlock(Type type, String toolUseId, String content) {
746760
* @param index The index of the content block.
747761
*/
748762
public ContentBlock(Type type, Source source, String text, Integer index) {
749-
this(type, source, text, index, null, null, null, null, null);
763+
this(type, source, text, index, null, null, null, null, null, null);
750764
}
751765

752766
// Tool use input JSON delta streaming
@@ -758,7 +772,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) {
758772
* @param input The input of the tool use.
759773
*/
760774
public ContentBlock(Type type, String id, String name, Map<String, Object> input) {
761-
this(type, null, null, null, id, name, input, null, null);
775+
this(type, null, null, null, id, name, input, null, null, null);
762776
}
763777

764778
/**
@@ -917,7 +931,9 @@ public record ChatCompletionResponse(
917931
public record Usage(
918932
// @formatter:off
919933
@JsonProperty("input_tokens") Integer inputTokens,
920-
@JsonProperty("output_tokens") Integer outputTokens) {
934+
@JsonProperty("output_tokens") Integer outputTokens,
935+
@JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens,
936+
@JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) {
921937
// @formatter:off
922938
}
923939

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.springframework.ai.anthropic.api;
2+
3+
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;
4+
5+
import java.util.function.Supplier;
6+
7+
public enum AnthropicCacheType {
8+
9+
EPHEMERAL(() -> new CacheControl("ephemeral"));
10+
11+
private Supplier<CacheControl> value;
12+
13+
AnthropicCacheType(Supplier<CacheControl> value) {
14+
this.value = value;
15+
}
16+
17+
public CacheControl cacheControl() {
18+
return this.value.get();
19+
}
20+
21+
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
174174

175175
if (messageDeltaEvent.usage() != null) {
176176
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177-
messageDeltaEvent.usage().outputTokens());
177+
messageDeltaEvent.usage().outputTokens(),
178+
contentBlockReference.get().usage.cacheCreationInputTokens(),
179+
contentBlockReference.get().usage.cacheReadInputTokens());
178180
contentBlockReference.get().withUsage(totalUsage);
179181
}
180182
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
3030
import org.springframework.http.ResponseEntity;
3131

32+
import org.springframework.web.client.RestClient;
33+
import org.springframework.web.reactive.function.client.WebClient;
34+
import reactor.core.publisher.Flux;
3235
import static org.assertj.core.api.Assertions.assertThat;
3336
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3437

@@ -41,6 +44,34 @@ public class AnthropicApiIT {
4144

4245
AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY"));
4346

47+
@Test
48+
void chatWithPromptCache() {
49+
String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which "
50+
+ "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" "
51+
+ "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English "
52+
+ "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". "
53+
+ "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\"";
54+
55+
AnthropicMessage chatCompletionMessage = new AnthropicMessage(
56+
List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())),
57+
Role.USER);
58+
59+
ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(
60+
AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8,
61+
false);
62+
AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest)
63+
.getBody()
64+
.usage();
65+
66+
assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0);
67+
assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0);
68+
69+
AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage();
70+
71+
assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0);
72+
assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0);
73+
}
74+
4475
@Test
4576
void chatCompletionEntity() {
4677

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,25 @@ public abstract class AbstractMessage implements Message {
5151
*/
5252
protected final String textContent;
5353

54+
protected String cache;
55+
5456
/**
5557
* Additional options for the message to influence the response, not a generative map.
5658
*/
5759
protected final Map<String, Object> metadata;
5860

59-
/**
60-
* Create a new AbstractMessage with the given message type, text content, and
61-
* metadata.
62-
* @param messageType the message type
63-
* @param textContent the text content
64-
* @param metadata the metadata
65-
*/
61+
protected AbstractMessage(MessageType messageType, String textContent, Map<String, Object> metadata, String cache) {
62+
Assert.notNull(messageType, "Message type must not be null");
63+
if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) {
64+
Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages");
65+
}
66+
this.messageType = messageType;
67+
this.textContent = textContent;
68+
this.metadata = new HashMap<>(metadata);
69+
this.metadata.put(MESSAGE_TYPE, messageType);
70+
this.cache = cache;
71+
}
72+
6673
protected AbstractMessage(MessageType messageType, String textContent, Map<String, Object> metadata) {
6774
Assert.notNull(messageType, "Message type must not be null");
6875
if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) {
@@ -93,6 +100,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map<String
93100
this.metadata.put(MESSAGE_TYPE, messageType);
94101
}
95102

103+
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> metadata, String cache) {
104+
Assert.notNull(resource, "Resource must not be null");
105+
try (InputStream inputStream = resource.getInputStream()) {
106+
this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
107+
}
108+
catch (IOException ex) {
109+
throw new RuntimeException("Failed to read resource", ex);
110+
}
111+
this.messageType = messageType;
112+
this.metadata = new HashMap<>(metadata);
113+
this.metadata.put(MESSAGE_TYPE, messageType);
114+
this.cache = cache;
115+
}
116+
96117
/**
97118
* Get the content of the message.
98119
* @return the content of the message
@@ -120,6 +141,10 @@ public MessageType getMessageType() {
120141
return this.messageType;
121142
}
122143

144+
public String getCache() {
145+
return cache;
146+
}
147+
123148
@Override
124149
public boolean equals(Object o) {
125150
if (this == o) {

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public class UserMessage extends AbstractMessage implements MediaContent {
3636

3737
protected final List<Media> media;
3838

39+
public UserMessage(String textContent, String cache) {
40+
this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache);
41+
}
42+
3943
public UserMessage(String textContent) {
4044
this(MessageType.USER, textContent, new ArrayList<>(), Map.of());
4145
}
@@ -45,6 +49,11 @@ public UserMessage(Resource resource) {
4549
this.media = new ArrayList<>();
4650
}
4751

52+
public UserMessage(Resource resource, String cache) {
53+
super(MessageType.USER, resource, Map.of(), cache);
54+
this.media = new ArrayList<>();
55+
}
56+
4857
public UserMessage(String textContent, List<Media> media) {
4958
this(MessageType.USER, textContent, media, Map.of());
5059
}
@@ -64,6 +73,17 @@ public UserMessage(MessageType messageType, String textContent, Collection<Media
6473
this.media = new ArrayList<>(media);
6574
}
6675

76+
public UserMessage(MessageType messageType, String textContent, Collection<Media> media,
77+
Map<String, Object> metadata, String cache) {
78+
super(messageType, textContent, metadata, cache);
79+
Assert.notNull(media, "media data must not be null");
80+
this.media = new ArrayList<>(media);
81+
}
82+
83+
public List<Media> getMedia(String... dummy) {
84+
return this.media;
85+
}
86+
6787
@Override
6888
public String toString() {
6989
return "UserMessage{" + "content='" + getText() + '\'' + ", properties=" + this.metadata + ", messageType="
@@ -80,4 +100,9 @@ public String getText() {
80100
return this.textContent;
81101
}
82102

103+
@Override
104+
public String getCache() {
105+
return super.getCache();
106+
}
107+
83108
}

0 commit comments

Comments
 (0)