Commit 54820c9

feat: add prompt_cache_key
1 parent 765f0aa

29 files changed (+122 −42 lines)

src/api/index.ts

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,7 @@ import {
 import { NativeOllamaHandler } from "./providers/native-ollama"
 
 export interface SingleCompletionHandler {
-	completePrompt(prompt: string): Promise<string>
+	completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string>
 }
 
 export interface ApiHandlerCreateMessageMetadata {
@@ -65,6 +65,7 @@ export interface ApiHandlerCreateMessageMetadata {
 	 * @default true
 	 */
 	store?: boolean
+	safetyIdentifier?: string
 }
 
 export interface ApiHandler {
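
For orientation, a minimal sketch (not part of this commit) of how a handler built against the extended interface forwards the new fields to an OpenAI-style endpoint. The abridged metadata type and the model id are placeholders, and prompt_cache_key / safety_identifier are only typed in openai SDK versions that expose these chat-completion params.

import OpenAI from "openai"

// Abridged stand-in for the interface above; taskId is included because the
// handlers below reuse it as the cache key.
interface ApiHandlerCreateMessageMetadata {
	taskId?: string
	store?: boolean
	safetyIdentifier?: string
}

const client = new OpenAI() // reads OPENAI_API_KEY from the environment

async function completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
	const response = await client.chat.completions.create({
		model: "gpt-4o-mini", // placeholder model id
		messages: [{ role: "user", content: prompt }],
		// The task id doubles as the cache key, so all requests belonging to
		// one task target the same prompt-cache entry.
		prompt_cache_key: metadata?.taskId,
		safety_identifier: metadata?.safetyIdentifier,
	})
	return response.choices[0]?.message.content || ""
}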

src/api/providers/anthropic-vertex.ts

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple
 		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
 	}
 
-	async completePrompt(prompt: string) {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata) {
 		try {
 			let {
 				id,

src/api/providers/anthropic.ts

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 		}
 	}
 
-	async completePrompt(prompt: string) {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata) {
 		let { id: model, temperature } = this.getModel()
 
 		const message = await this.client.messages.create({

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 5 additions & 1 deletion
@@ -83,6 +83,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			prompt_cache_key: metadata?.taskId,
+			safety_identifier: metadata?.safetyIdentifier,
 		}
 
 		try {
@@ -119,13 +121,15 @@
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
 		const { id: modelId } = this.getModel()
 
 		try {
 			const response = await this.client.chat.completions.create({
 				model: modelId,
 				messages: [{ role: "user", content: prompt }],
+				prompt_cache_key: metadata?.taskId,
+				safety_identifier: metadata?.safetyIdentifier,
 			})
 
 			return response.choices[0]?.message.content || ""
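
One reason the base provider can attach these fields unconditionally: when the caller passes no metadata, both values are undefined, and undefined-valued properties disappear when the request body is serialized with JSON.stringify (which is how the openai client builds its payloads), so the keys never reach the wire. A quick illustration of that serialization behavior:

const body = {
	model: "some-model",
	messages: [{ role: "user", content: "hi" }],
	prompt_cache_key: undefined,
	safety_identifier: undefined,
}

console.log(JSON.stringify(body))
// → {"model":"some-model","messages":[{"role":"user","content":"hi"}]}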

src/api/providers/bedrock.ts

Lines changed: 1 addition & 1 deletion
@@ -633,7 +633,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
 		try {
 			const modelConfig = this.getModel()
 
src/api/providers/cerebras.ts

Lines changed: 3 additions & 1 deletion
@@ -277,14 +277,16 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
 		const { id: model } = this.getModel()
 
 		// Prepare request body for non-streaming completion
 		const requestBody = {
 			model,
 			messages: [{ role: "user", content: prompt }],
 			stream: false,
+			prompt_cache_key: metadata?.taskId,
+			safety_identifier: metadata?.safetyIdentifier,
 		}
 
 		try {

src/api/providers/chutes.ts

Lines changed: 9 additions & 2 deletions
@@ -9,6 +9,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+import { ApiHandlerCreateMessageMetadata } from ".."
 
 export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
 	constructor(options: ApiHandlerOptions) {
@@ -44,13 +45,19 @@
 		}
 	}
 
-	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
 		const model = this.getModel()
 
 		if (model.id.includes("DeepSeek-R1")) {
 			const stream = await this.client.chat.completions.create({
 				...this.getCompletionParams(systemPrompt, messages),
 				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+				prompt_cache_key: metadata?.taskId,
+				safety_identifier: metadata?.safetyIdentifier,
 			})
 
 			const matcher = new XmlMatcher(
@@ -85,7 +92,7 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
 				yield processedChunk
 			}
 		} else {
-			yield* super.createMessage(systemPrompt, messages)
+			yield* super.createMessage(systemPrompt, messages, metadata)
 		}
 	}
 

src/api/providers/deepinfra.ts

Lines changed: 5 additions & 6 deletions
@@ -60,18 +60,15 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 		// Ensure we have up-to-date model metadata
 		await this.fetchModel()
 		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
-		let prompt_cache_key = undefined
-		if (info.supportsPromptCache && _metadata?.taskId) {
-			prompt_cache_key = _metadata.taskId
-		}
 
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 			reasoning_effort,
-			prompt_cache_key,
+			prompt_cache_key: _metadata?.taskId,
+			safety_identifier: _metadata?.safetyIdentifier,
 		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
 
 		if (this.supportsTemperature(modelId)) {
@@ -106,13 +103,15 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
 		await this.fetchModel()
 		const { id: modelId, info } = this.getModel()
 
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 			model: modelId,
 			messages: [{ role: "user", content: prompt }],
+			prompt_cache_key: metadata?.taskId,
+			safety_identifier: metadata?.safetyIdentifier,
 		}
 		if (this.supportsTemperature(modelId)) {
 			requestOptions.temperature = this.options.modelTemperature ?? 0

src/api/providers/fake-ai.ts

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ interface FakeAI {
 	): ApiStream
 	getModel(): { id: string; info: ModelInfo }
 	countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>
-	completePrompt(prompt: string): Promise<string>
+	completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string>
 }
 
 /**
@@ -75,7 +75,7 @@ export class FakeAIHandler implements ApiHandler, SingleCompletionHandler {
 		return this.ai.countTokens(content)
 	}
 
-	completePrompt(prompt: string): Promise<string> {
+	completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
 		return this.ai.completePrompt(prompt)
 	}
 }

src/api/providers/featherless.ts

Lines changed: 11 additions & 3 deletions
@@ -14,6 +14,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+import { ApiHandlerCreateMessageMetadata } from ".."
 
 export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<FeatherlessModelId> {
 	constructor(options: ApiHandlerOptions) {
@@ -31,6 +32,7 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<Featherless
 	private getCompletionParams(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
 		const {
 			id: model,
@@ -46,15 +48,21 @@
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			prompt_cache_key: metadata?.taskId,
+			safety_identifier: metadata?.safetyIdentifier,
 		}
 	}
 
-	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
 		const model = this.getModel()
 
 		if (model.id.includes("DeepSeek-R1")) {
 			const stream = await this.client.chat.completions.create({
-				...this.getCompletionParams(systemPrompt, messages),
+				...this.getCompletionParams(systemPrompt, messages, metadata),
 				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
 			})
 
@@ -90,7 +98,7 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<Featherless
 				yield processedChunk
 			}
 		} else {
-			yield* super.createMessage(systemPrompt, messages)
+			yield* super.createMessage(systemPrompt, messages, metadata)
 		}
 	}
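
A hypothetical call site (not part of this commit) showing how the threaded metadata is meant to be used; the constructor option name and the chunk shape are assumptions based on the signatures in these diffs:

import { FeatherlessHandler } from "./featherless"

// Option name is hypothetical; the cast keeps the sketch independent of the
// real ApiHandlerOptions shape.
const handler = new FeatherlessHandler({ featherlessApiKey: "..." } as any)

const stream = handler.createMessage(
	"You are a helpful assistant.",
	[{ role: "user", content: "Hello" }],
	// taskId becomes prompt_cache_key; safetyIdentifier becomes safety_identifier.
	{ taskId: "task-1234", safetyIdentifier: "user-abc" },
)

for await (const chunk of stream) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
}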
