Commit d460f43

feat: add prompt caching support for LiteLLM (#5791)
- Add litellmUsePromptCache configuration option to provider settings
- Implement cache control headers in LiteLLM handler when enabled
- Add UI checkbox for enabling prompt caching (only shown for supported models)
- Track cache read/write tokens in usage data
- Add comprehensive test for prompt caching functionality
- Reuse existing translation keys for consistency across languages

This allows LiteLLM users to benefit from prompt caching with supported models like Claude 3.7, reducing costs and improving response times.
1 parent 9fce90b commit d460f43
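As a rough sketch of how the new option sits alongside the existing LiteLLM settings (the litellm* keys come from the schema change below; the object name and values here are placeholders):

// Hypothetical settings object; only the field names are taken from this commit.
const litellmConfig = {
	litellmBaseUrl: "http://localhost:4000",
	litellmApiKey: "sk-example",
	litellmModelId: "claude-3-7-sonnet",
	// New in this commit; caching stays off when the flag is omitted.
	litellmUsePromptCache: true,
}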

File tree

4 files changed (+201, -6 lines)

packages/types/src/provider-settings.ts

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
 	litellmBaseUrl: z.string().optional(),
 	litellmApiKey: z.string().optional(),
 	litellmModelId: z.string().optional(),
+	litellmUsePromptCache: z.boolean().optional(),
 })
 
 const defaultSchema = z.object({
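Because the new field is optional, configurations written before this change keep validating. A minimal standalone sketch of that behaviour (the real litellmSchema extends baseProviderSettingsSchema, which is outside this diff):

import { z } from "zod"

// Standalone approximation of the litellm fields shown above.
const litellmSketch = z.object({
	litellmBaseUrl: z.string().optional(),
	litellmApiKey: z.string().optional(),
	litellmModelId: z.string().optional(),
	litellmUsePromptCache: z.boolean().optional(),
})

// Parses with or without the new flag, so existing settings remain valid.
litellmSketch.parse({ litellmModelId: "claude-3-7-sonnet" })
litellmSketch.parse({ litellmModelId: "claude-3-7-sonnet", litellmUsePromptCache: true })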
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { LiteLLMHandler } from "../lite-llm"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
+
+// Mock vscode first to avoid import errors
+vi.mock("vscode", () => ({}))
+
+// Mock OpenAI
+vi.mock("openai", () => {
+	const mockStream = {
+		[Symbol.asyncIterator]: vi.fn(),
+	}
+
+	const mockCreate = vi.fn().mockReturnValue({
+		withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+	})
+
+	return {
+		default: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+// Mock model fetching
+vi.mock("../fetchers/modelCache", () => ({
+	getModels: vi.fn().mockImplementation(() => {
+		return Promise.resolve({
+			[litellmDefaultModelId]: litellmDefaultModelInfo,
+		})
+	}),
+}))
+
+describe("LiteLLMHandler", () => {
+	let handler: LiteLLMHandler
+	let mockOptions: ApiHandlerOptions
+	let mockOpenAIClient: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockOptions = {
+			litellmApiKey: "test-key",
+			litellmBaseUrl: "http://localhost:4000",
+			litellmModelId: litellmDefaultModelId,
+		}
+		handler = new LiteLLMHandler(mockOptions)
+		mockOpenAIClient = new OpenAI()
+	})
+
+	describe("prompt caching", () => {
+		it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
+			const optionsWithCache: ApiHandlerOptions = {
+				...mockOptions,
+				litellmUsePromptCache: true,
+			}
+			handler = new LiteLLMHandler(optionsWithCache)
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{ role: "user", content: "Hello" },
+				{ role: "assistant", content: "Hi there!" },
+				{ role: "user", content: "How are you?" },
+			]
+
+			// Mock the stream response
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						choices: [{ delta: { content: "I'm doing well!" } }],
+						usage: {
+							prompt_tokens: 100,
+							completion_tokens: 50,
+							cache_creation_input_tokens: 20,
+							cache_read_input_tokens: 30,
+						},
+					}
+				},
+			}
+
+			mockOpenAIClient.chat.completions.create.mockReturnValue({
+				withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+			})
+
+			const generator = handler.createMessage(systemPrompt, messages)
+			const results = []
+			for await (const chunk of generator) {
+				results.push(chunk)
+			}
+
+			// Verify that create was called with cache control headers
+			const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
+
+			// Check system message has cache control
+			expect(createCall.messages[0]).toMatchObject({
+				role: "system",
+				content: systemPrompt,
+				cache_control: { type: "ephemeral" },
+			})
+
+			// Check that the last two user messages have cache control
+			const userMessageIndices = createCall.messages
+				.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
+				.filter((idx: number) => idx !== -1)
+
+			const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]
+
+			expect(createCall.messages[lastUserIdx]).toMatchObject({
+				cache_control: { type: "ephemeral" },
+			})
+
+			// Verify usage includes cache tokens
+			const usageChunk = results.find((chunk) => chunk.type === "usage")
+			expect(usageChunk).toMatchObject({
+				type: "usage",
+				inputTokens: 100,
+				outputTokens: 50,
+				cacheWriteTokens: 20,
+				cacheReadTokens: 30,
+			})
+		})
+	})
+})

src/api/providers/lite-llm.ts

Lines changed: 46 additions & 5 deletions
@@ -44,13 +44,44 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 			...convertToOpenAiMessages(messages),
 		]
 
+		// Apply cache control if prompt caching is enabled and supported
+		let enhancedMessages = openAiMessages
+		if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
+			const cacheControl = { cache_control: { type: "ephemeral" } }
+
+			// Add cache control to system message
+			enhancedMessages[0] = {
+				...enhancedMessages[0],
+				...cacheControl,
+			}
+
+			// Find the last two user messages to apply caching
+			const userMsgIndices = enhancedMessages.reduce(
+				(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+				[] as number[],
+			)
+			const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+			const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+			// Apply cache_control to the last two user messages
+			enhancedMessages = enhancedMessages.map((message, index) => {
+				if (index === lastUserMsgIndex || index === secondLastUserMsgIndex) {
+					return {
+						...message,
+						...cacheControl,
+					}
+				}
+				return message
+			})
+		}
+
 		// Required by some providers; others default to max tokens allowed
 		let maxTokens: number | undefined = info.maxTokens ?? undefined
 
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			max_tokens: maxTokens,
-			messages: openAiMessages,
+			messages: enhancedMessages,
 			stream: true,
 			stream_options: {
 				include_usage: true,

@@ -80,20 +111,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 		}
 
 		if (lastUsage) {
+			// Extract cache-related information if available
+			// LiteLLM may use different field names for cache tokens
+			const cacheWriteTokens =
+				lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
+			const cacheReadTokens =
+				lastUsage.prompt_tokens_details?.cached_tokens ||
+				(lastUsage as any).cache_read_input_tokens ||
+				(lastUsage as any).prompt_cache_hit_tokens ||
+				0
+
 			const usageData: ApiStreamUsageChunk = {
 				type: "usage",
 				inputTokens: lastUsage.prompt_tokens || 0,
 				outputTokens: lastUsage.completion_tokens || 0,
-				cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
-				cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
+				cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+				cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
 			}
 
 			usageData.totalCost = calculateApiCostOpenAI(
 				info,
 				usageData.inputTokens,
 				usageData.outputTokens,
-				usageData.cacheWriteTokens,
-				usageData.cacheReadTokens,
+				usageData.cacheWriteTokens || 0,
+				usageData.cacheReadTokens || 0,
 			)
 
 			yield usageData
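Taken out of the handler, the cache-control placement above amounts to the following standalone sketch (the message type is simplified; the real code operates on the OpenAI chat message params and only runs when both litellmUsePromptCache and the model's supportsPromptCache flag are set):

type SketchMessage = {
	role: "system" | "user" | "assistant"
	content: string
	cache_control?: { type: "ephemeral" }
}

// Mark the system prompt (index 0) and the last two user messages as
// ephemeral cache breakpoints, mirroring the hunk above.
function applyCacheControl(messages: SketchMessage[]): SketchMessage[] {
	const cacheControl = { cache_control: { type: "ephemeral" as const } }
	const userIndices = messages.reduce(
		(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
		[] as number[],
	)
	const lastTwoUserIndices = new Set(userIndices.slice(-2))
	return messages.map((message, index) =>
		index === 0 || lastTwoUserIndices.has(index) ? { ...message, ...cacheControl } : message,
	)
}

The enriched array then replaces openAiMessages in the request options, and the usage handling above maps the provider's cache token fields onto cacheWriteTokens and cacheReadTokens.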

webview-ui/src/components/settings/providers/LiteLLM.tsx

Lines changed: 24 additions & 1 deletion
@@ -1,5 +1,5 @@
 import { useCallback, useState, useEffect, useRef } from "react"
-import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
 
 import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"
 

@@ -151,6 +151,29 @@ export const LiteLLM = ({
 				organizationAllowList={organizationAllowList}
 				errorMessage={modelValidationError}
 			/>
+
+			{/* Show prompt caching option if the selected model supports it */}
+			{(() => {
+				const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
+				const selectedModel = routerModels?.litellm?.[selectedModelId]
+				if (selectedModel?.supportsPromptCache) {
+					return (
+						<div className="mt-4">
+							<VSCodeCheckbox
+								checked={apiConfiguration.litellmUsePromptCache || false}
+								onChange={(e: any) => {
+									setApiConfigurationField("litellmUsePromptCache", e.target.checked)
+								}}>
+								<span className="font-medium">{t("settings:providers.enablePromptCaching")}</span>
+							</VSCodeCheckbox>
+							<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1">
+								{t("settings:providers.enablePromptCachingTitle")}
+							</div>
+						</div>
+					)
+				}
+				return null
+			})()}
 		</>
 	)
 }
