Skip to content

Commit a0018c9

Browse files
authored
feat: add prompt caching support for LiteLLM (#5791) (#6074)
* feat: add prompt caching support for LiteLLM (#5791)

  - Add `litellmUsePromptCache` configuration option to provider settings
  - Implement cache control headers in the LiteLLM handler when enabled
  - Add a UI checkbox for enabling prompt caching (only shown for supported models)
  - Track cache read/write tokens in usage data
  - Add a comprehensive test for the prompt-caching functionality
  - Reuse existing translation keys for consistency across languages

  This allows LiteLLM users to benefit from prompt caching with supported models such as Claude 3.7, reducing costs and improving response times.

* fix: improve LiteLLM prompt caching so it works for multi-turn conversations

  - Convert the system message to a structured format with `cache_control`
  - Handle both string and array content types for user messages
  - Apply `cache_control` to content items, not just at the message level
  - Update tests to match the new message structure

  This ensures prompt caching works correctly for all messages in a conversation, not just the initial system prompt and first user message.

* fix: resolve TypeScript linter error for the `cache_control` property

  Use a type assertion to handle the `cache_control` property, which is not present in the OpenAI types.
1 parent 499d12c commit a0018c9

File tree

4 files changed

+262
-10
lines changed

4 files changed

+262
-10
lines changed

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
238238
litellmBaseUrl: z.string().optional(),
239239
litellmApiKey: z.string().optional(),
240240
litellmModelId: z.string().optional(),
241+
litellmUsePromptCache: z.boolean().optional(),
241242
})
242243

243244
const defaultSchema = z.object({
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import OpenAI from "openai"
3+
import { Anthropic } from "@anthropic-ai/sdk"
4+
5+
import { LiteLLMHandler } from "../lite-llm"
6+
import { ApiHandlerOptions } from "../../../shared/api"
7+
import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
8+
9+
// Mock vscode first to avoid import errors when transitively importing
// extension modules in a plain Node test environment.
vi.mock("vscode", () => ({}))

// Mock OpenAI. Note that `mockCreate` is created once in the factory closure,
// so every `new OpenAI()` instance shares the same `create` spy — this is what
// lets the test construct its own client and still observe the calls made by
// the handler's internal client.
vi.mock("openai", () => {
	// Minimal stand-in for a streaming response body; individual tests
	// override the return value with a real async iterator as needed.
	const mockStream = {
		[Symbol.asyncIterator]: vi.fn(),
	}

	// `withResponse()` mirrors the SDK shape used by the handler to obtain
	// the stream (`{ data }`).
	const mockCreate = vi.fn().mockReturnValue({
		withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
	})

	return {
		default: vi.fn().mockImplementation(() => ({
			chat: {
				completions: {
					create: mockCreate,
				},
			},
		})),
	}
})

// Mock model fetching so fetchModel() resolves the default LiteLLM model
// without any network access.
vi.mock("../fetchers/modelCache", () => ({
	getModels: vi.fn().mockImplementation(() => {
		return Promise.resolve({
			[litellmDefaultModelId]: litellmDefaultModelInfo,
		})
	}),
}))
41+
42+
describe("LiteLLMHandler", () => {
43+
let handler: LiteLLMHandler
44+
let mockOptions: ApiHandlerOptions
45+
let mockOpenAIClient: any
46+
47+
beforeEach(() => {
48+
vi.clearAllMocks()
49+
mockOptions = {
50+
litellmApiKey: "test-key",
51+
litellmBaseUrl: "http://localhost:4000",
52+
litellmModelId: litellmDefaultModelId,
53+
}
54+
handler = new LiteLLMHandler(mockOptions)
55+
mockOpenAIClient = new OpenAI()
56+
})
57+
58+
describe("prompt caching", () => {
59+
it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
60+
const optionsWithCache: ApiHandlerOptions = {
61+
...mockOptions,
62+
litellmUsePromptCache: true,
63+
}
64+
handler = new LiteLLMHandler(optionsWithCache)
65+
66+
const systemPrompt = "You are a helpful assistant"
67+
const messages: Anthropic.Messages.MessageParam[] = [
68+
{ role: "user", content: "Hello" },
69+
{ role: "assistant", content: "Hi there!" },
70+
{ role: "user", content: "How are you?" },
71+
]
72+
73+
// Mock the stream response
74+
const mockStream = {
75+
async *[Symbol.asyncIterator]() {
76+
yield {
77+
choices: [{ delta: { content: "I'm doing well!" } }],
78+
usage: {
79+
prompt_tokens: 100,
80+
completion_tokens: 50,
81+
cache_creation_input_tokens: 20,
82+
cache_read_input_tokens: 30,
83+
},
84+
}
85+
},
86+
}
87+
88+
mockOpenAIClient.chat.completions.create.mockReturnValue({
89+
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
90+
})
91+
92+
const generator = handler.createMessage(systemPrompt, messages)
93+
const results = []
94+
for await (const chunk of generator) {
95+
results.push(chunk)
96+
}
97+
98+
// Verify that create was called with cache control headers
99+
const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
100+
101+
// Check system message has cache control in the proper format
102+
expect(createCall.messages[0]).toMatchObject({
103+
role: "system",
104+
content: [
105+
{
106+
type: "text",
107+
text: systemPrompt,
108+
cache_control: { type: "ephemeral" },
109+
},
110+
],
111+
})
112+
113+
// Check that the last two user messages have cache control
114+
const userMessageIndices = createCall.messages
115+
.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
116+
.filter((idx: number) => idx !== -1)
117+
118+
const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]
119+
const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2]
120+
121+
// Check last user message has proper structure with cache control
122+
expect(createCall.messages[lastUserIdx]).toMatchObject({
123+
role: "user",
124+
content: [
125+
{
126+
type: "text",
127+
text: "How are you?",
128+
cache_control: { type: "ephemeral" },
129+
},
130+
],
131+
})
132+
133+
// Check second last user message (first user message in this case)
134+
if (secondLastUserIdx !== -1) {
135+
expect(createCall.messages[secondLastUserIdx]).toMatchObject({
136+
role: "user",
137+
content: [
138+
{
139+
type: "text",
140+
text: "Hello",
141+
cache_control: { type: "ephemeral" },
142+
},
143+
],
144+
})
145+
}
146+
147+
// Verify usage includes cache tokens
148+
const usageChunk = results.find((chunk) => chunk.type === "usage")
149+
expect(usageChunk).toMatchObject({
150+
type: "usage",
151+
inputTokens: 100,
152+
outputTokens: 50,
153+
cacheWriteTokens: 20,
154+
cacheReadTokens: 30,
155+
})
156+
})
157+
})
158+
})

src/api/providers/lite-llm.ts

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,78 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
3939
): ApiStream {
4040
const { id: modelId, info } = await this.fetchModel()
4141

42-
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
43-
{ role: "system", content: systemPrompt },
44-
...convertToOpenAiMessages(messages),
45-
]
42+
const openAiMessages = convertToOpenAiMessages(messages)
43+
44+
// Prepare messages with cache control if enabled and supported
45+
let systemMessage: OpenAI.Chat.ChatCompletionMessageParam
46+
let enhancedMessages: OpenAI.Chat.ChatCompletionMessageParam[]
47+
48+
if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
49+
// Create system message with cache control in the proper format
50+
systemMessage = {
51+
role: "system",
52+
content: [
53+
{
54+
type: "text",
55+
text: systemPrompt,
56+
cache_control: { type: "ephemeral" },
57+
} as any,
58+
],
59+
}
60+
61+
// Find the last two user messages to apply caching
62+
const userMsgIndices = openAiMessages.reduce(
63+
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
64+
[] as number[],
65+
)
66+
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
67+
const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
68+
69+
// Apply cache_control to the last two user messages
70+
enhancedMessages = openAiMessages.map((message, index) => {
71+
if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && message.role === "user") {
72+
// Handle both string and array content types
73+
if (typeof message.content === "string") {
74+
return {
75+
...message,
76+
content: [
77+
{
78+
type: "text",
79+
text: message.content,
80+
cache_control: { type: "ephemeral" },
81+
} as any,
82+
],
83+
}
84+
} else if (Array.isArray(message.content)) {
85+
// Apply cache control to the last content item in the array
86+
return {
87+
...message,
88+
content: message.content.map((content, contentIndex) =>
89+
contentIndex === message.content.length - 1
90+
? ({
91+
...content,
92+
cache_control: { type: "ephemeral" },
93+
} as any)
94+
: content,
95+
),
96+
}
97+
}
98+
}
99+
return message
100+
})
101+
} else {
102+
// No cache control - use simple format
103+
systemMessage = { role: "system", content: systemPrompt }
104+
enhancedMessages = openAiMessages
105+
}
46106

47107
// Required by some providers; others default to max tokens allowed
48108
let maxTokens: number | undefined = info.maxTokens ?? undefined
49109

50110
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
51111
model: modelId,
52112
max_tokens: maxTokens,
53-
messages: openAiMessages,
113+
messages: [systemMessage, ...enhancedMessages],
54114
stream: true,
55115
stream_options: {
56116
include_usage: true,
@@ -80,20 +140,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
80140
}
81141

82142
if (lastUsage) {
143+
// Extract cache-related information if available
144+
// LiteLLM may use different field names for cache tokens
145+
const cacheWriteTokens =
146+
lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
147+
const cacheReadTokens =
148+
lastUsage.prompt_tokens_details?.cached_tokens ||
149+
(lastUsage as any).cache_read_input_tokens ||
150+
(lastUsage as any).prompt_cache_hit_tokens ||
151+
0
152+
83153
const usageData: ApiStreamUsageChunk = {
84154
type: "usage",
85155
inputTokens: lastUsage.prompt_tokens || 0,
86156
outputTokens: lastUsage.completion_tokens || 0,
87-
cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
88-
cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
157+
cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
158+
cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
89159
}
90160

91161
usageData.totalCost = calculateApiCostOpenAI(
92162
info,
93163
usageData.inputTokens,
94164
usageData.outputTokens,
95-
usageData.cacheWriteTokens,
96-
usageData.cacheReadTokens,
165+
usageData.cacheWriteTokens || 0,
166+
usageData.cacheReadTokens || 0,
97167
)
98168

99169
yield usageData

webview-ui/src/components/settings/providers/LiteLLM.tsx

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { useCallback, useState, useEffect, useRef } from "react"
2-
import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
2+
import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
33

44
import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"
55

@@ -151,6 +151,29 @@ export const LiteLLM = ({
151151
organizationAllowList={organizationAllowList}
152152
errorMessage={modelValidationError}
153153
/>
154+
155+
{/* Show prompt caching option if the selected model supports it */}
156+
{(() => {
157+
const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
158+
const selectedModel = routerModels?.litellm?.[selectedModelId]
159+
if (selectedModel?.supportsPromptCache) {
160+
return (
161+
<div className="mt-4">
162+
<VSCodeCheckbox
163+
checked={apiConfiguration.litellmUsePromptCache || false}
164+
onChange={(e: any) => {
165+
setApiConfigurationField("litellmUsePromptCache", e.target.checked)
166+
}}>
167+
<span className="font-medium">{t("settings:providers.enablePromptCaching")}</span>
168+
</VSCodeCheckbox>
169+
<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1">
170+
{t("settings:providers.enablePromptCachingTitle")}
171+
</div>
172+
</div>
173+
)
174+
}
175+
return null
176+
})()}
154177
</>
155178
)
156179
}

0 commit comments

Comments
 (0)