Skip to content

Commit 15eb26d

Browse files
mikhailmokhovhntrl
andauthored
feat(community): Support AWS Bedrock Converse API Prompt Caching (#8524)
Co-authored-by: Hunter Lovell <[email protected]>
1 parent 6e972e4 commit 15eb26d

File tree

2 files changed

+169
-25
lines changed

2 files changed

+169
-25
lines changed

libs/langchain-aws/src/common.ts

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ import {
5151
MessageContentReasoningBlockRedacted,
5252
} from "./types.js";
5353

54+
function isDefaultCachePoint(block: unknown): boolean {
55+
return Boolean(
56+
typeof block === "object" &&
57+
block !== null &&
58+
"cachePoint" in block &&
59+
block.cachePoint &&
60+
typeof block.cachePoint === "object" &&
61+
block.cachePoint !== null &&
62+
"type" in block.cachePoint &&
63+
block.cachePoint.type === "default"
64+
);
65+
}
66+
5467
const standardContentBlockConverter: StandardContentBlockConverter<{
5568
text: ContentBlock.TextMember;
5669
image: ContentBlock.ImageMember;
@@ -310,6 +323,14 @@ function convertLangChainContentBlockToConverseContentBlock<
310323
};
311324
}
312325

326+
if (isDefaultCachePoint(block)) {
327+
return {
328+
cachePoint: {
329+
type: "default",
330+
},
331+
};
332+
}
333+
313334
if (onUnknown === "throw") {
314335
throw new Error(`Unsupported content block type: ${block.type}`);
315336
} else {
@@ -319,14 +340,28 @@ function convertLangChainContentBlockToConverseContentBlock<
319340

320341
function convertSystemMessageToConverseMessage(
321342
msg: SystemMessage
322-
): BedrockSystemContentBlock {
343+
): BedrockSystemContentBlock[] {
323344
if (typeof msg.content === "string") {
324-
return { text: msg.content };
325-
} else if (msg.content.length === 1 && msg.content[0].type === "text") {
326-
return { text: msg.content[0].text };
345+
return [{ text: msg.content }];
346+
} else if (Array.isArray(msg.content) && msg.content.length > 0) {
347+
const contentBlocks: BedrockSystemContentBlock[] = [];
348+
for (const block of msg.content) {
349+
if (block.type === "text" && typeof block.text === "string") {
350+
contentBlocks.push({
351+
text: block.text,
352+
});
353+
} else if (isDefaultCachePoint(block)) {
354+
contentBlocks.push({
355+
cachePoint: {
356+
type: "default",
357+
},
358+
});
359+
} else break;
360+
}
361+
if (msg.content.length === contentBlocks.length) return contentBlocks;
327362
}
328363
throw new Error(
329-
"System message content must be either a string, or a content array containing a single text object."
364+
"System message content must be either a string, or an array of text blocks, optionally including a cache point."
330365
);
331366
}
332367

@@ -353,6 +388,12 @@ function convertAIMessageToConverseMessage(msg: AIMessage): BedrockMessage {
353388
block as MessageContentReasoningBlock
354389
),
355390
};
391+
} else if (isDefaultCachePoint(block)) {
392+
return {
393+
cachePoint: {
394+
type: "default",
395+
},
396+
};
356397
} else {
357398
const blockValues = Object.fromEntries(
358399
Object.entries(block).filter(([key]) => key !== "type")
@@ -393,7 +434,7 @@ function convertHumanMessageToConverseMessage(
393434
): BedrockMessage {
394435
if (msg.content === "") {
395436
throw new Error(
396-
`Invalid message content: empty string. '${msg._getType()}' must contain non-empty content.`
437+
`Invalid message content: empty string. '${msg.getType()}' must contain non-empty content.`
397438
);
398439
}
399440

@@ -469,20 +510,20 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
469510
converseSystem: BedrockSystemContentBlock[];
470511
} {
471512
const converseSystem: BedrockSystemContentBlock[] = messages
472-
.filter((msg) => msg._getType() === "system")
473-
.map((msg) => convertSystemMessageToConverseMessage(msg));
513+
.filter((msg) => msg.getType() === "system")
514+
.flatMap((msg) => convertSystemMessageToConverseMessage(msg));
474515

475516
const converseMessages: BedrockMessage[] = messages
476-
.filter((msg) => msg._getType() !== "system")
517+
.filter((msg) => msg.getType() !== "system")
477518
.map((msg) => {
478-
if (msg._getType() === "ai") {
519+
if (msg.getType() === "ai") {
479520
return convertAIMessageToConverseMessage(msg as AIMessage);
480-
} else if (msg._getType() === "human" || msg._getType() === "generic") {
521+
} else if (msg.getType() === "human" || msg.getType() === "generic") {
481522
return convertHumanMessageToConverseMessage(msg as HumanMessage);
482-
} else if (msg._getType() === "tool") {
523+
} else if (msg.getType() === "tool") {
483524
return convertToolMessageToConverseMessage(msg as ToolMessage);
484525
} else {
485-
throw new Error(`Unsupported message type: ${msg._getType()}`);
526+
throw new Error(`Unsupported message type: ${msg.getType()}`);
486527
}
487528
});
488529

libs/langchain-aws/src/tests/chat_models.test.ts

Lines changed: 115 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import {
77
BaseMessage,
88
} from "@langchain/core/messages";
99
import { concat } from "@langchain/core/utils/stream";
10-
import type {
11-
Message as BedrockMessage,
12-
SystemContentBlock as BedrockSystemContentBlock,
10+
import {
11+
ConversationRole as BedrockConversationRole,
12+
type Message as BedrockMessage,
13+
type SystemContentBlock as BedrockSystemContentBlock,
1314
} from "@aws-sdk/client-bedrock-runtime";
1415
import { z } from "zod";
1516
import { describe, expect, test } from "@jest/globals";
@@ -63,15 +64,15 @@ describe("convertToConverseMessages", () => {
6364
output: {
6465
converseMessages: [
6566
{
66-
role: "user",
67+
role: BedrockConversationRole.USER,
6768
content: [
6869
{
6970
text: "What's the weather like today in Berkeley, CA? Use weather.com to check.",
7071
},
7172
],
7273
},
7374
{
74-
role: "assistant",
75+
role: BedrockConversationRole.ASSISTANT,
7576
content: [
7677
{
7778
toolUse: {
@@ -85,7 +86,7 @@ describe("convertToConverseMessages", () => {
8586
],
8687
},
8788
{
88-
role: "user",
89+
role: BedrockConversationRole.USER,
8990
content: [
9091
{
9192
toolResult: {
@@ -107,6 +108,108 @@ describe("convertToConverseMessages", () => {
107108
],
108109
},
109110
},
111+
{
112+
name: "prompt caching",
113+
input: [
114+
new SystemMessage({
115+
content: [
116+
{ type: "text", text: "You're an advanced AI assistant." },
117+
{
118+
cachePoint: {
119+
type: "default",
120+
},
121+
},
122+
{
123+
type: "text",
124+
text: "Answer the user's questions using your own knowledge or provided tool.",
125+
},
126+
],
127+
}),
128+
new HumanMessage({
129+
content: [
130+
{
131+
type: "text",
132+
text: "What is the capital of France?",
133+
},
134+
{
135+
cachePoint: {
136+
type: "default",
137+
},
138+
},
139+
{
140+
type: "text",
141+
text: "And what is the capital of Germany?",
142+
},
143+
],
144+
}),
145+
new AIMessage({
146+
content: [
147+
{
148+
type: "text",
149+
text: "Sure! The capital of France is Paris.",
150+
},
151+
{
152+
cachePoint: {
153+
type: "default",
154+
},
155+
},
156+
{
157+
type: "text",
158+
text: "The capital of Germany is Berlin.",
159+
},
160+
],
161+
}),
162+
],
163+
output: {
164+
converseMessages: [
165+
{
166+
role: BedrockConversationRole.USER,
167+
content: [
168+
{
169+
text: "What is the capital of France?",
170+
},
171+
{
172+
cachePoint: {
173+
type: "default",
174+
},
175+
},
176+
{
177+
text: "And what is the capital of Germany?",
178+
},
179+
],
180+
},
181+
{
182+
role: BedrockConversationRole.ASSISTANT,
183+
content: [
184+
{
185+
text: "Sure! The capital of France is Paris.",
186+
},
187+
{
188+
cachePoint: {
189+
type: "default",
190+
},
191+
},
192+
{
193+
text: "The capital of Germany is Berlin.",
194+
},
195+
],
196+
},
197+
],
198+
converseSystem: [
199+
{
200+
text: "You're an advanced AI assistant.",
201+
},
202+
{
203+
cachePoint: {
204+
type: "default",
205+
},
206+
},
207+
{
208+
text: "Answer the user's questions using your own knowledge or provided tool.",
209+
},
210+
],
211+
},
212+
},
110213
{
111214
name: "consecutive user tool messages",
112215
input: [
@@ -180,15 +283,15 @@ describe("convertToConverseMessages", () => {
180283
],
181284
converseMessages: [
182285
{
183-
role: "user",
286+
role: BedrockConversationRole.USER,
184287
content: [
185288
{
186289
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use weather.com to check.",
187290
},
188291
],
189292
},
190293
{
191-
role: "assistant",
294+
role: BedrockConversationRole.ASSISTANT,
192295
content: [
193296
{
194297
toolUse: {
@@ -211,7 +314,7 @@ describe("convertToConverseMessages", () => {
211314
],
212315
},
213316
{
214-
role: "user",
317+
role: BedrockConversationRole.USER,
215318
content: [
216319
{
217320
toolResult: {
@@ -236,15 +339,15 @@ describe("convertToConverseMessages", () => {
236339
],
237340
},
238341
{
239-
role: "user",
342+
role: BedrockConversationRole.USER,
240343
content: [
241344
{
242345
text: "What's the weather like today in Berkeley, CA and in Paris, France? Use meteofrance.com to check.",
243346
},
244347
],
245348
},
246349
{
247-
role: "assistant",
350+
role: BedrockConversationRole.ASSISTANT,
248351
content: [
249352
{
250353
toolUse: {
@@ -267,7 +370,7 @@ describe("convertToConverseMessages", () => {
267370
],
268371
},
269372
{
270-
role: "user",
373+
role: BedrockConversationRole.USER,
271374
content: [
272375
{
273376
toolResult: {

0 commit comments

Comments
 (0)