Merged · Changes from 3 commits
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -15,6 +15,7 @@ import { MistralHandler } from "./providers/mistral"
import { VsCodeLmHandler } from "./providers/vscode-lm"
import { ApiStream } from "./transform/stream"
import { UnboundHandler } from "./providers/unbound"
import { RequestyHandler } from "./providers/requesty"

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
@@ -56,6 +57,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
return new MistralHandler(options)
case "unbound":
return new UnboundHandler(options)
case "requesty":
return new RequestyHandler(options)
default:
return new AnthropicHandler(options)
}
247 changes: 247 additions & 0 deletions src/api/providers/__tests__/requesty.test.ts
@@ -0,0 +1,247 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { ApiHandlerOptions, ModelInfo, requestyModelInfoSaneDefaults } from "../../../shared/api"
import { RequestyHandler } from "../requesty"
import { convertToOpenAiMessages } from "../../transform/openai-format"
import { convertToR1Format } from "../../transform/r1-format"

// Mock OpenAI and transform functions
jest.mock("openai")
jest.mock("../../transform/openai-format")
jest.mock("../../transform/r1-format")

describe("RequestyHandler", () => {
let handler: RequestyHandler
let mockCreate: jest.Mock

const defaultOptions: ApiHandlerOptions = {
requestyApiKey: "test-key",
requestyModelId: "test-model",
requestyModelInfo: {
maxTokens: 1000,
contextWindow: 4000,
supportsPromptCache: false,
supportsImages: true,
inputPrice: 0,
outputPrice: 0,
},
openAiStreamingEnabled: true,
includeMaxTokens: true, // Add this to match the implementation
}

beforeEach(() => {
// Clear mocks
jest.clearAllMocks()

// Setup mock create function
mockCreate = jest.fn()

// Mock OpenAI constructor
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(
() =>
({
chat: {
completions: {
create: mockCreate,
},
},
}) as unknown as OpenAI,
)

// Mock transform functions
;(convertToOpenAiMessages as jest.Mock).mockImplementation((messages) => messages)
;(convertToR1Format as jest.Mock).mockImplementation((messages) => messages)

// Create handler instance
handler = new RequestyHandler(defaultOptions)
})

describe("constructor", () => {
it("should initialize with correct options", () => {
expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://router.requesty.ai/v1",
apiKey: defaultOptions.requestyApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
},
})
})
})

describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

describe("with streaming enabled", () => {
beforeEach(() => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{ delta: { content: "Hello" } }],
}
yield {
choices: [{ delta: { content: " world" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
},
}
mockCreate.mockResolvedValue(stream)
})

it("should handle streaming response correctly", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const results = []

for await (const chunk of stream) {
results.push(chunk)
}

expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " world" },
{
type: "usage",
inputTokens: 10,
outputTokens: 5,
cacheWriteTokens: undefined,
cacheReadTokens: undefined,
},
])

expect(mockCreate).toHaveBeenCalledWith({
model: defaultOptions.requestyModelId,
temperature: 0,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello" },
],
stream: true,
stream_options: { include_usage: true },
max_tokens: defaultOptions.requestyModelInfo?.maxTokens,
})
})

it("should not include max_tokens when includeMaxTokens is false", async () => {
handler = new RequestyHandler({
...defaultOptions,
includeMaxTokens: false,
})

await handler.createMessage(systemPrompt, messages).next()

expect(mockCreate).toHaveBeenCalledWith(
expect.not.objectContaining({
max_tokens: expect.any(Number),
}),
)
})

it("should handle deepseek-reasoner model format", async () => {
handler = new RequestyHandler({
...defaultOptions,
requestyModelId: "deepseek-reasoner",
})

await handler.createMessage(systemPrompt, messages).next()

expect(convertToR1Format).toHaveBeenCalledWith([{ role: "user", content: systemPrompt }, ...messages])
})
})

describe("with streaming disabled", () => {
beforeEach(() => {
handler = new RequestyHandler({
...defaultOptions,
openAiStreamingEnabled: false,
})

mockCreate.mockResolvedValue({
choices: [{ message: { content: "Hello world" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
})
})

it("should handle non-streaming response correctly", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const results = []

for await (const chunk of stream) {
results.push(chunk)
}

expect(results).toEqual([
{ type: "text", text: "Hello world" },
{
type: "usage",
inputTokens: 10,
outputTokens: 5,
},
])

expect(mockCreate).toHaveBeenCalledWith({
model: defaultOptions.requestyModelId,
messages: [
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello" },
],
})
})
})
})

describe("getModel", () => {
it("should return correct model information", () => {
const result = handler.getModel()
expect(result).toEqual({
id: defaultOptions.requestyModelId,
info: defaultOptions.requestyModelInfo,
})
})

it("should use sane defaults when no model info provided", () => {
handler = new RequestyHandler({
...defaultOptions,
requestyModelInfo: undefined,
})

const result = handler.getModel()
expect(result).toEqual({
id: defaultOptions.requestyModelId,
info: requestyModelInfoSaneDefaults,
})
})
})

describe("completePrompt", () => {
beforeEach(() => {
mockCreate.mockResolvedValue({
choices: [{ message: { content: "Completed response" } }],
})
})

it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Completed response")
expect(mockCreate).toHaveBeenCalledWith({
model: defaultOptions.requestyModelId,
messages: [{ role: "user", content: "Test prompt" }],
})
})

it("should handle errors correctly", async () => {
const errorMessage = "API error"
mockCreate.mockRejectedValue(new Error(errorMessage))

await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
`OpenAI completion error: ${errorMessage}`,
)
})
})
})
6 changes: 3 additions & 3 deletions src/api/providers/deepseek.ts
@@ -1,9 +1,9 @@
-import { OpenAiHandler } from "./openai"
-import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
+import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"
+import { ModelInfo } from "../../shared/api"
import { deepSeekModels, deepSeekDefaultModelId } from "../../shared/api"

export class DeepSeekHandler extends OpenAiHandler {
-constructor(options: ApiHandlerOptions) {
+constructor(options: OpenAiHandlerOptions) {
super({
...options,
openAiApiKey: options.deepSeekApiKey ?? "not-provided",
32 changes: 18 additions & 14 deletions src/api/providers/openai.ts
@@ -11,13 +11,17 @@ import { ApiHandler, SingleCompletionHandler } from "../index"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToR1Format } from "../transform/r1-format"
import { convertToSimpleMessages } from "../transform/simple-format"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"

+export interface OpenAiHandlerOptions extends ApiHandlerOptions {
+defaultHeaders?: Record<string, string>
+}
+
export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
-protected options: ApiHandlerOptions
+protected options: OpenAiHandlerOptions
private client: OpenAI

-constructor(options: ApiHandlerOptions) {
+constructor(options: OpenAiHandlerOptions) {
this.options = options

const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
@@ -41,7 +45,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
})
} else {
-this.client = new OpenAI({ baseURL, apiKey })
+this.client = new OpenAI({ baseURL, apiKey, defaultHeaders: this.options.defaultHeaders })
}
}

@@ -98,11 +102,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
}
}
if (chunk.usage) {
-yield {
-type: "usage",
-inputTokens: chunk.usage.prompt_tokens || 0,
-outputTokens: chunk.usage.completion_tokens || 0,
-}
+yield this.processUsageMetrics(chunk.usage)
}
}
} else {
@@ -125,11 +125,15 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
type: "text",
text: response.choices[0]?.message.content || "",
}
-yield {
-type: "usage",
-inputTokens: response.usage?.prompt_tokens || 0,
-outputTokens: response.usage?.completion_tokens || 0,
-}
+yield this.processUsageMetrics(response.usage)
}
}

+protected processUsageMetrics(usage: any): ApiStreamUsageChunk {
+return {
+type: "usage",
+inputTokens: usage?.prompt_tokens || 0,
+outputTokens: usage?.completion_tokens || 0,
+}
+}

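The streaming and non-streaming branches above now route usage reporting through the protected processUsageMetrics hook, so a provider subclass can attach provider-specific fields (RequestyHandler below adds cache token counts) without re-implementing createMessage. A minimal consumer-side sketch of draining the resulting ApiStream; the helper name and the "../index" import path are illustrative assumptions, not part of this PR:

```typescript
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandler } from "../index"

// Illustrative helper (not part of this PR): drain an ApiStream and tally usage.
async function collectResponse(
	handler: ApiHandler,
	systemPrompt: string,
	messages: Anthropic.Messages.MessageParam[],
) {
	let text = ""
	let inputTokens = 0
	let outputTokens = 0

	for await (const chunk of handler.createMessage(systemPrompt, messages)) {
		if (chunk.type === "text") {
			text += chunk.text
		} else if (chunk.type === "usage") {
			// Providers that override processUsageMetrics may also populate
			// cacheWriteTokens / cacheReadTokens on this chunk.
			inputTokens += chunk.inputTokens
			outputTokens += chunk.outputTokens
		}
	}

	return { text, inputTokens, outputTokens }
}
```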
40 changes: 40 additions & 0 deletions src/api/providers/requesty.ts
@@ -0,0 +1,40 @@
import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"
import { ModelInfo, requestyModelInfoSaneDefaults, requestyDefaultModelId } from "../../shared/api"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"

export class RequestyHandler extends OpenAiHandler {
constructor(options: OpenAiHandlerOptions) {
if (!options.requestyApiKey) {
throw new Error("Requesty API key is required. Please provide it in the settings.")
}
super({
...options,
openAiApiKey: options.requestyApiKey,
openAiModelId: options.requestyModelId ?? requestyDefaultModelId,
openAiBaseUrl: "https://router.requesty.ai/v1",
openAiCustomModelInfo: options.requestyModelInfo ?? requestyModelInfoSaneDefaults,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
},
})
}

override getModel(): { id: string; info: ModelInfo } {
const modelId = this.options.requestyModelId ?? requestyDefaultModelId
return {
id: modelId,
info: this.options.requestyModelInfo ?? requestyModelInfoSaneDefaults,
}
}

protected override processUsageMetrics(usage: any): ApiStreamUsageChunk {
return {
type: "usage",
inputTokens: usage?.prompt_tokens || 0,
outputTokens: usage?.completion_tokens || 0,
cacheWriteTokens: usage?.cache_creation_input_tokens,
cacheReadTokens: usage?.cache_read_input_tokens,
}
}
}
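
Together with the new case in buildApiHandler (src/api/index.ts above), a rough sketch of how the provider would be wired up end to end. The apiProvider field name is inferred from that switch, the configuration is cast loosely because the full ApiConfiguration shape is not shown in this diff, and the key/model values are placeholders mirroring the test fixtures:

```typescript
import { buildApiHandler, SingleCompletionHandler } from "../index"

async function demo() {
	// Assumed configuration shape; only the requesty-specific fields appear in this PR.
	const handler = buildApiHandler({
		apiProvider: "requesty",
		requestyApiKey: "<your-api-key>",
		requestyModelId: "test-model",
	} as any)

	console.log(handler.getModel().id) // "test-model"

	// RequestyHandler inherits completePrompt from OpenAiHandler, as the tests above exercise.
	const reply = await (handler as unknown as SingleCompletionHandler).completePrompt("Say hello")
	console.log(reply)
}
```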