3 changes: 2 additions & 1 deletion src/api/index.ts
@@ -44,7 +44,7 @@ import {
import { NativeOllamaHandler } from "./providers/native-ollama"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string>
}

export interface ApiHandlerCreateMessageMetadata {
@@ -65,6 +65,7 @@ export interface ApiHandlerCreateMessageMetadata {
* @default true
*/
store?: boolean
safetyIdentifier?: string
}

export interface ApiHandler {
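For orientation, a minimal sketch of how a caller could thread the new metadata through the widened `completePrompt` signature. The caller, its import path, and the `safetyIdentifier` value are illustrative, not part of this PR:

```typescript
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "./index"

// Illustrative caller: build the metadata once and pass it as the new optional
// second argument. Handlers that ignore the argument keep working unchanged.
async function completeWithMetadata(
	handler: SingleCompletionHandler,
	prompt: string,
	taskId: string,
): Promise<string> {
	const metadata: ApiHandlerCreateMessageMetadata = {
		taskId,
		safetyIdentifier: "hashed-user-123", // assumption: an opaque, pre-hashed identifier
	}
	return handler.completePrompt(prompt, metadata)
}
```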
2 changes: 1 addition & 1 deletion src/api/providers/anthropic-vertex.ts
@@ -175,7 +175,7 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleCompletionHandler
return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
}

async completePrompt(prompt: string) {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata) {
try {
let {
id,
2 changes: 1 addition & 1 deletion src/api/providers/anthropic.ts
@@ -278,7 +278,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string) {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata) {
let { id: model, temperature } = this.getModel()

const message = await this.client.messages.create({
6 changes: 5 additions & 1 deletion src/api/providers/base-openai-compatible-provider.ts
@@ -83,6 +83,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
prompt_cache_key: metadata?.taskId,
Review comment (contributor): P0: Official OpenAI endpoints may reject unknown request args (e.g., 'Unrecognized request argument: prompt_cache_key'). Please gate 'prompt_cache_key' and 'safety_identifier' so they’re only sent to endpoints that accept them (OpenRouter, vLLM/sglang gateways, etc.). Otherwise this can cause 400s for users on api.openai.com.
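As a sketch of the gating suggested above (not part of this PR): `endpointAcceptsCacheFields` and its allow-list are hypothetical placeholders for whatever per-provider capability check the codebase would actually use.

```typescript
import type { ApiHandlerCreateMessageMetadata } from ".."

// Hypothetical predicate: whether the configured endpoint is known to accept
// prompt_cache_key / safety_identifier. The allow-list below is illustrative.
function endpointAcceptsCacheFields(baseUrl?: string): boolean {
	const allowList = ["openrouter.ai", "localhost", "127.0.0.1"]
	return !!baseUrl && allowList.some((host) => baseUrl.includes(host))
}

// Returns either both fields or an empty object, so call sites can spread the
// result into their existing create() params without sending unknown args.
export function cacheAndSafetyFields(
	baseUrl: string | undefined,
	metadata?: ApiHandlerCreateMessageMetadata,
): Record<string, string | undefined> {
	return endpointAcceptsCacheFields(baseUrl)
		? { prompt_cache_key: metadata?.taskId, safety_identifier: metadata?.safetyIdentifier }
		: {}
}
```

Call sites would then spread `...cacheAndSafetyFields(baseUrl, metadata)` instead of listing the two keys unconditionally.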

safety_identifier: metadata?.safetyIdentifier,
}

try {
@@ -119,13 +121,15 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
const { id: modelId } = this.getModel()

try {
const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: "user", content: prompt }],
prompt_cache_key: metadata?.taskId,
Review comment (contributor): P1: The OpenAI SDK param types may not permit extra fields. If you keep these fields, ensure types allow them or route via a supported 'extra body' mechanism. Otherwise TS type-check or runtime validation could fail depending on the SDK/version.
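If the pinned openai SDK's types reject these keys, one localized workaround is sketched below; it assumes the SDK forwards unrecognized body keys at runtime, which is worth verifying for the version in use:

```typescript
import OpenAI from "openai"

type ChatParams = OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming

// Keep the type widening in one helper so call sites stay typed as ChatParams
// and no `as any` spreads through the providers. The assertion only silences
// excess-property checking; the extra keys are still sent in the request body.
function withCacheFields(params: ChatParams, taskId?: string, safetyIdentifier?: string): ChatParams {
	return {
		...params,
		prompt_cache_key: taskId,
		safety_identifier: safetyIdentifier,
	} as ChatParams
}
```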

safety_identifier: metadata?.safetyIdentifier,
})

return response.choices[0]?.message.content || ""
2 changes: 1 addition & 1 deletion src/api/providers/bedrock.ts
@@ -633,7 +633,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
try {
const modelConfig = this.getModel()

4 changes: 3 additions & 1 deletion src/api/providers/cerebras.ts
@@ -277,14 +277,16 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
const { id: model } = this.getModel()

// Prepare request body for non-streaming completion
const requestBody = {
model,
messages: [{ role: "user", content: prompt }],
stream: false,
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

try {
11 changes: 9 additions & 2 deletions src/api/providers/chutes.ts
@@ -9,6 +9,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
import { ApiHandlerCreateMessageMetadata } from ".."

export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
constructor(options: ApiHandlerOptions) {
@@ -44,13 +45,19 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
}
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const model = this.getModel()

if (model.id.includes("DeepSeek-R1")) {
const stream = await this.client.chat.completions.create({
...this.getCompletionParams(systemPrompt, messages),
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
})

const matcher = new XmlMatcher(
@@ -85,7 +92,7 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
yield processedChunk
}
} else {
yield* super.createMessage(systemPrompt, messages)
yield* super.createMessage(systemPrompt, messages, metadata)
}
}

11 changes: 5 additions & 6 deletions src/api/providers/deepinfra.ts
@@ -60,18 +60,15 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler
// Ensure we have up-to-date model metadata
await this.fetchModel()
const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
let prompt_cache_key = undefined
if (info.supportsPromptCache && _metadata?.taskId) {
prompt_cache_key = _metadata.taskId
}

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
reasoning_effort,
prompt_cache_key,
prompt_cache_key: _metadata?.taskId,
safety_identifier: _metadata?.safetyIdentifier,
} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming

if (this.supportsTemperature(modelId)) {
@@ -106,13 +103,15 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
await this.fetchModel()
const { id: modelId, info } = this.getModel()

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [{ role: "user", content: prompt }],
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}
if (this.supportsTemperature(modelId)) {
requestOptions.temperature = this.options.modelTemperature ?? 0
4 changes: 2 additions & 2 deletions src/api/providers/fake-ai.ts
@@ -28,7 +28,7 @@ interface FakeAI {
): ApiStream
getModel(): { id: string; info: ModelInfo }
countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>
completePrompt(prompt: string): Promise<string>
completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string>
}

/**
@@ -75,7 +75,7 @@ export class FakeAIHandler implements ApiHandler, SingleCompletionHandler {
return this.ai.countTokens(content)
}

completePrompt(prompt: string): Promise<string> {
completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
return this.ai.completePrompt(prompt)
}
}
21 changes: 17 additions & 4 deletions src/api/providers/featherless.ts
@@ -1,4 +1,9 @@
import { DEEP_SEEK_DEFAULT_TEMPERATURE, type FeatherlessModelId, featherlessDefaultModelId, featherlessModels } from "@roo-code/types"
import {
DEEP_SEEK_DEFAULT_TEMPERATURE,
type FeatherlessModelId,
featherlessDefaultModelId,
featherlessModels,
} from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

@@ -9,6 +14,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
import { ApiHandlerCreateMessageMetadata } from ".."

export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<FeatherlessModelId> {
constructor(options: ApiHandlerOptions) {
@@ -26,6 +32,7 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<FeatherlessModelId>
private getCompletionParams(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
const {
id: model,
@@ -41,15 +48,21 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<FeatherlessModelId>
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const model = this.getModel()

if (model.id.includes("DeepSeek-R1")) {
const stream = await this.client.chat.completions.create({
...this.getCompletionParams(systemPrompt, messages),
...this.getCompletionParams(systemPrompt, messages, metadata),
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
})

@@ -85,7 +98,7 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider<FeatherlessModelId>
yield processedChunk
}
} else {
yield* super.createMessage(systemPrompt, messages)
yield* super.createMessage(systemPrompt, messages, metadata)
}
}

2 changes: 1 addition & 1 deletion src/api/providers/gemini.ts
@@ -209,7 +209,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandler
return citationLinks.join(", ")
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
try {
const { id: model } = this.getModel()

4 changes: 3 additions & 1 deletion src/api/providers/glama.ts
@@ -116,13 +116,15 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
const { id: modelId, info } = await this.fetchModel()

try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [{ role: "user", content: prompt }],
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

if (this.supportsTemperature(modelId)) {
6 changes: 5 additions & 1 deletion src/api/providers/huggingface.ts
@@ -59,6 +59,8 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

// Add max_tokens if specified
@@ -93,13 +95,15 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"

try {
const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: "user", content: prompt }],
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
})

return response.choices[0]?.message.content || ""
2 changes: 1 addition & 1 deletion src/api/providers/human-relay.ts
@@ -82,7 +82,7 @@ export class HumanRelayHandler implements ApiHandler, SingleCompletionHandler {
* Implementation of a single prompt
* @param prompt Prompt content
*/
async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
// Copy to clipboard
await vscode.env.clipboard.writeText(prompt)

6 changes: 5 additions & 1 deletion src/api/providers/lite-llm.ts
@@ -123,6 +123,8 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHandler
stream_options: {
include_usage: true,
},
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

// GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
@@ -191,7 +193,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
const { id: modelId, info } = await this.fetchModel()

// Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens
@@ -201,6 +203,8 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHandler
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [{ role: "user", content: prompt }],
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

if (this.supportsTemperature(modelId)) {
6 changes: 5 additions & 1 deletion src/api/providers/lm-studio.ts
@@ -87,6 +87,8 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler
messages: openAiMessages,
temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
stream: true,
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

if (this.options.lmStudioSpeculativeDecodingEnabled && this.options.lmStudioDraftModelId) {
@@ -159,14 +161,16 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
try {
// Create params object with optional draft model
const params: any = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
temperature: this.options.modelTemperature ?? LMSTUDIO_DEFAULT_TEMPERATURE,
stream: false,
prompt_cache_key: metadata?.taskId,
safety_identifier: metadata?.safetyIdentifier,
}

// Add draft model if speculative decoding is enabled and a draft model is specified
2 changes: 1 addition & 1 deletion src/api/providers/mistral.ts
@@ -104,7 +104,7 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHandler
return { id, info, maxTokens, temperature }
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
try {
const { id: model, temperature } = this.getModel()

2 changes: 1 addition & 1 deletion src/api/providers/native-ollama.ts
@@ -280,7 +280,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletionHandler
}
}

async completePrompt(prompt: string): Promise<string> {
async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise<string> {
try {
const client = this.ensureClient()
const { id: modelId } = await this.fetchModel()