
Commit 801946f

feat: allow enabling prompt caching for LiteLLM + Claude (RooCodeInc#2627)
* feat: allow enabling prompt caching for LiteLLM + Claude
1 parent 13b6941 commit 801946f

6 files changed: +87 −5 lines changed
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+---
+"claude-dev": patch
+---
+
+allow enabling prompt caching for LiteLLM + Claude

src/api/providers/litellm.ts

Lines changed: 48 additions & 3 deletions

@@ -71,9 +71,37 @@ export class LiteLlmHandler implements ApiHandler {
             temperature = undefined // Thinking mode doesn't support temperature
         }

+        // Define cache control object if prompt caching is enabled
+        const cacheControl = this.options.liteLlmUsePromptCache ? { cache_control: { type: "ephemeral" } } : undefined
+
+        // Add cache_control to system message if enabled
+        const enhancedSystemMessage = {
+            ...systemMessage,
+            ...(cacheControl && cacheControl),
+        }
+
+        // Find the last two user messages to apply caching
+        const userMsgIndices = formattedMessages.reduce(
+            (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+            [] as number[],
+        )
+        const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+        const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+        // Apply cache_control to the last two user messages if enabled
+        const enhancedMessages = formattedMessages.map((message, index) => {
+            if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && cacheControl) {
+                return {
+                    ...message,
+                    ...cacheControl,
+                }
+            }
+            return message
+        })
+
         const stream = await this.client.chat.completions.create({
             model: this.options.liteLlmModelId || liteLlmDefaultModelId,
-            messages: [systemMessage, ...formattedMessages],
+            messages: [enhancedSystemMessage, ...enhancedMessages],
             temperature,
             stream: true,
             stream_options: { include_usage: true },

@@ -111,10 +139,27 @@ export class LiteLlmHandler implements ApiHandler {
             if (chunk.usage) {
                 const totalCost =
                     (inputCost * chunk.usage.prompt_tokens) / 1e6 + (outputCost * chunk.usage.completion_tokens) / 1e6
+
+                // Extract cache-related information if available
+                // Need to use type assertion since these properties are not in the standard OpenAI types
+                const usage = chunk.usage as {
+                    prompt_tokens: number
+                    completion_tokens: number
+                    cache_creation_input_tokens?: number
+                    prompt_cache_miss_tokens?: number
+                    cache_read_input_tokens?: number
+                    prompt_cache_hit_tokens?: number
+                }
+
+                const cacheWriteTokens = usage.cache_creation_input_tokens || usage.prompt_cache_miss_tokens || 0
+                const cacheReadTokens = usage.cache_read_input_tokens || usage.prompt_cache_hit_tokens || 0
+
                 yield {
                     type: "usage",
-                    inputTokens: chunk.usage.prompt_tokens || 0,
-                    outputTokens: chunk.usage.completion_tokens || 0,
+                    inputTokens: usage.prompt_tokens || 0,
+                    outputTokens: usage.completion_tokens || 0,
+                    cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+                    cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
                     totalCost,
                 }
             }
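
In effect, when liteLlmUsePromptCache is set the handler attaches Anthropic-style cache_control breakpoints to the system prompt and the two most recent user messages before the request goes through the OpenAI-compatible LiteLLM client. On the way back, cache_creation_input_tokens and cache_read_input_tokens are Anthropic's usage field names, while the prompt_cache_miss_tokens / prompt_cache_hit_tokens fallbacks cover proxies that report caching with DeepSeek-style names. Below is a minimal, standalone sketch of the message tagging only; the ChatMessage type and the applyPromptCache helper are illustrative and not part of the codebase.

// Sketch of the cache_control tagging performed by the handler (hypothetical types and helper).
type ChatMessage = {
    role: "system" | "user" | "assistant"
    content: string
    cache_control?: { type: "ephemeral" }
}

function applyPromptCache(systemMessage: ChatMessage, messages: ChatMessage[], usePromptCache: boolean): ChatMessage[] {
    const cacheControl = usePromptCache ? { cache_control: { type: "ephemeral" as const } } : undefined

    // The system prompt gets a breakpoint so its prefix can be reused across requests.
    const enhancedSystemMessage = { ...systemMessage, ...(cacheControl && cacheControl) }

    // Only the last two user turns are tagged, keeping the growing conversation prefix cacheable.
    const userMsgIndices = messages.reduce((acc, msg, i) => (msg.role === "user" ? [...acc, i] : acc), [] as number[])
    const lastTwoUserTurns = new Set(userMsgIndices.slice(-2))

    const enhancedMessages = messages.map((msg, i) => (cacheControl && lastTwoUserTurns.has(i) ? { ...msg, ...cacheControl } : msg))
    return [enhancedSystemMessage, ...enhancedMessages]
}

// Example: the system message plus the two most recent user turns carry cache_control.
console.log(
    applyPromptCache(
        { role: "system", content: "You are a coding assistant." },
        [
            { role: "user", content: "Refactor foo()" },
            { role: "assistant", content: "Done." },
            { role: "user", content: "Now add tests." },
        ],
        true,
    ),
)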

src/core/storage/state-keys.ts

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ export type GlobalStateKey =
     | "previousModeModelInfo"
     | "liteLlmBaseUrl"
     | "liteLlmModelId"
+    | "liteLlmUsePromptCache"
     | "qwenApiLine"
     | "requestyModelId"
     | "togetherModelId"

src/core/storage/state.ts

Lines changed: 5 additions & 0 deletions

@@ -101,6 +101,7 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
         vsCodeLmModelSelector,
         liteLlmBaseUrl,
         liteLlmModelId,
+        liteLlmUsePromptCache,
         userInfo,
         previousModeApiProvider,
         previousModeModelId,

@@ -166,6 +167,7 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
         getGlobalState(context, "vsCodeLmModelSelector") as Promise<vscode.LanguageModelChatSelector | undefined>,
         getGlobalState(context, "liteLlmBaseUrl") as Promise<string | undefined>,
         getGlobalState(context, "liteLlmModelId") as Promise<string | undefined>,
+        getGlobalState(context, "liteLlmUsePromptCache") as Promise<boolean | undefined>,
         getGlobalState(context, "userInfo") as Promise<UserInfo | undefined>,
         getGlobalState(context, "previousModeApiProvider") as Promise<ApiProvider | undefined>,
         getGlobalState(context, "previousModeModelId") as Promise<string | undefined>,

@@ -268,6 +270,7 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
         liteLlmBaseUrl,
         liteLlmModelId,
         liteLlmApiKey,
+        liteLlmUsePromptCache,
         asksageApiKey,
         asksageApiUrl,
         xaiApiKey,

@@ -336,6 +339,7 @@ export async function updateApiConfiguration(context: vscode.ExtensionContext, a
         liteLlmBaseUrl,
         liteLlmModelId,
         liteLlmApiKey,
+        liteLlmUsePromptCache,
         qwenApiLine,
         asksageApiKey,
         asksageApiUrl,

@@ -386,6 +390,7 @@ export async function updateApiConfiguration(context: vscode.ExtensionContext, a
     await updateGlobalState(context, "vsCodeLmModelSelector", vsCodeLmModelSelector)
     await updateGlobalState(context, "liteLlmBaseUrl", liteLlmBaseUrl)
     await updateGlobalState(context, "liteLlmModelId", liteLlmModelId)
+    await updateGlobalState(context, "liteLlmUsePromptCache", liteLlmUsePromptCache)
     await updateGlobalState(context, "qwenApiLine", qwenApiLine)
     await updateGlobalState(context, "requestyModelId", requestyModelId)
     await updateGlobalState(context, "togetherModelId", togetherModelId)
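
The new key rides through the same persistence helpers as the other LiteLLM settings. A minimal sketch of the round trip, reusing the getGlobalState and updateGlobalState calls shown above (context is the vscode.ExtensionContext those functions already receive):

// Sketch only: persist the checkbox value, then read it back as part of extension state.
await updateGlobalState(context, "liteLlmUsePromptCache", true)
const liteLlmUsePromptCache = (await getGlobalState(context, "liteLlmUsePromptCache")) as boolean | undefined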

src/shared/api.ts

Lines changed: 4 additions & 1 deletion

@@ -29,6 +29,7 @@ export interface ApiHandlerOptions {
     liteLlmBaseUrl?: string
     liteLlmModelId?: string
     liteLlmApiKey?: string
+    liteLlmUsePromptCache?: boolean
     anthropicBaseUrl?: string
     openRouterApiKey?: string
     openRouterModelId?: string

@@ -1239,9 +1240,11 @@ export const liteLlmModelInfoSaneDefaults: ModelInfo = {
     maxTokens: -1,
     contextWindow: 128_000,
     supportsImages: true,
-    supportsPromptCache: false,
+    supportsPromptCache: true,
     inputPrice: 0,
     outputPrice: 0,
+    cacheWritesPrice: 0,
+    cacheReadsPrice: 0,
 }

 // AskSage Models
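
Flipping supportsPromptCache to true is what lets the caching checkbox render in ApiOptions.tsx (it is gated on selectedModelInfo.supportsPromptCache), and the new cache prices default to 0 because a LiteLLM proxy can sit in front of deployments with unknown pricing. As a hedged sketch of how ModelInfo-style cache prices could be folded into a per-request cost alongside the cache token counts the handler now reports; note that this commit's totalCost only bills prompt and completion tokens, and estimateCost is a hypothetical helper:

// Hypothetical cost helper; field names mirror ModelInfo, prices are per million tokens.
interface CachePricing {
    inputPrice: number
    outputPrice: number
    cacheWritesPrice?: number
    cacheReadsPrice?: number
}

function estimateCost(
    p: CachePricing,
    tokens: { input: number; output: number; cacheWrites?: number; cacheReads?: number },
): number {
    return (
        (p.inputPrice * tokens.input) / 1e6 +
        (p.outputPrice * tokens.output) / 1e6 +
        ((p.cacheWritesPrice ?? 0) * (tokens.cacheWrites ?? 0)) / 1e6 +
        ((p.cacheReadsPrice ?? 0) * (tokens.cacheReads ?? 0)) / 1e6
    )
}

// With the sane defaults above (all prices 0) this evaluates to 0 for any usage.
console.log(estimateCost({ inputPrice: 0, outputPrice: 0, cacheWritesPrice: 0, cacheReadsPrice: 0 }, { input: 1200, output: 300, cacheReads: 900 }))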

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 24 additions & 1 deletion

@@ -47,6 +47,7 @@ import {
     sambanovaDefaultModelId,
     doubaoModels,
     doubaoDefaultModelId,
+    liteLlmModelInfoSaneDefaults,
 } from "@shared/api"
 import { ExtensionMessage } from "@shared/ExtensionMessage"
 import { useExtensionState } from "@/context/ExtensionStateContext"

@@ -1240,6 +1241,28 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage, is
                         <span style={{ fontWeight: 500 }}>Model ID</span>
                     </VSCodeTextField>

+                    <div style={{ display: "flex", flexDirection: "column", marginTop: 10, marginBottom: 10 }}>
+                        {selectedModelInfo.supportsPromptCache && (
+                            <>
+                                <VSCodeCheckbox
+                                    checked={apiConfiguration?.liteLlmUsePromptCache || false}
+                                    onChange={(e: any) => {
+                                        const isChecked = e.target.checked === true
+                                        setApiConfiguration({
+                                            ...apiConfiguration,
+                                            liteLlmUsePromptCache: isChecked,
+                                        })
+                                    }}
+                                    style={{ fontWeight: 500, color: "var(--vscode-charts-green)" }}>
+                                    Use prompt caching (GA)
+                                </VSCodeCheckbox>
+                                <p style={{ fontSize: "12px", marginTop: 3, color: "var(--vscode-charts-green)" }}>
+                                    Prompt caching requires a supported provider and model
+                                </p>
+                            </>
+                        )}
+                    </div>
+
                     <>
                         <ThinkingBudgetSlider apiConfiguration={apiConfiguration} setApiConfiguration={setApiConfiguration} />
                         <p

@@ -1778,7 +1801,7 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration):
             return {
                 selectedProvider: provider,
                 selectedModelId: apiConfiguration?.liteLlmModelId || "",
-                selectedModelInfo: openAiModelInfoSaneDefaults,
+                selectedModelInfo: liteLlmModelInfoSaneDefaults,
             }
         case "xai":
             return getProviderData(xaiModels, xaiDefaultModelId)
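
End to end, the checkbox writes liteLlmUsePromptCache into apiConfiguration, updateApiConfiguration persists it to global state, and getAllExtensionState hands it back to the provider layer as part of ApiHandlerOptions. A small sketch of the options object the handler would receive, with illustrative values that are not taken from the commit:

import type { ApiHandlerOptions } from "@shared/api"

// Example configuration with prompt caching enabled (base URL, key, and model ID are placeholders).
const options: ApiHandlerOptions = {
    liteLlmBaseUrl: "http://localhost:4000",
    liteLlmApiKey: "sk-example",
    liteLlmModelId: "claude-3-7-sonnet",
    liteLlmUsePromptCache: true,
}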
