15 changes: 15 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -21,6 +21,7 @@ import {
qwenCodeModels,
rooModels,
sambaNovaModels,
tarsModels,
vertexModels,
vscodeLlmModels,
xaiModels,
@@ -66,6 +67,7 @@ export const providerNames = [
"featherless",
"io-intelligence",
"roo",
"tars",
"vercel-ai-gateway",
] as const

@@ -265,6 +267,11 @@ const requestySchema = baseProviderSettingsSchema.extend({
requestyModelId: z.string().optional(),
})

const tarsSchema = baseProviderSettingsSchema.extend({
tarsApiKey: z.string().optional(),
tarsModelId: z.string().optional(),
})

const humanRelaySchema = baseProviderSettingsSchema

const fakeAiSchema = baseProviderSettingsSchema.extend({
@@ -359,6 +366,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })),
tarsSchema.merge(z.object({ apiProvider: z.literal("tars") })),
humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })),
fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
@@ -399,6 +407,7 @@ export const providerSettingsSchema = z.object({
...moonshotSchema.shape,
...unboundSchema.shape,
...requestySchema.shape,
...tarsSchema.shape,
...humanRelaySchema.shape,
...fakeAiSchema.shape,
...xaiSchema.shape,
@@ -414,6 +423,7 @@ export const providerSettingsSchema = z.object({
...ioIntelligenceSchema.shape,
...qwenCodeSchema.shape,
...rooSchema.shape,
...tarsSchema.shape,
...vercelAiGatewaySchema.shape,
...codebaseIndexProviderSchema.shape,
})
@@ -543,6 +553,11 @@ export const MODELS_BY_PROVIDER: Record<
label: "SambaNova",
models: Object.keys(sambaNovaModels),
},
tars: {
id: "tars",
label: "Tars",
models: Object.keys(tarsModels),
},
vertex: {
id: "vertex",
label: "GCP Vertex AI",
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -23,6 +23,7 @@ export * from "./qwen-code.js"
export * from "./requesty.js"
export * from "./roo.js"
export * from "./sambanova.js"
export * from "./tars.js"
export * from "./unbound.js"
export * from "./vertex.js"
export * from "./vscode-llm.js"
21 changes: 21 additions & 0 deletions packages/types/src/providers/tars.ts
@@ -0,0 +1,21 @@
import type { ModelInfo } from "../model.js"

export const tarsDefaultModelId = "claude-3-5-haiku-20241022"

export const tarsDefaultModelInfo: ModelInfo = {
maxTokens: 8192,
contextWindow: 200000,
supportsImages: true,
supportsComputerUse: false,
supportsPromptCache: true,
inputPrice: 0.8,
outputPrice: 4.0,
cacheWritesPrice: 1.0,
cacheReadsPrice: 0.08,
description:
"Claude 3.5 Haiku - Fast and cost-effective with excellent coding capabilities. Ideal for development tasks with 200k context window",
}

export const tarsModels = {
[tarsDefaultModelId]: tarsDefaultModelInfo,
} as const satisfies Record<string, ModelInfo>
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -22,6 +22,7 @@ import {
VsCodeLmHandler,
UnboundHandler,
RequestyHandler,
TarsHandler,
HumanRelayHandler,
FakeAIHandler,
XAIHandler,
@@ -130,6 +131,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new UnboundHandler(options)
case "requesty":
return new RequestyHandler(options)
case "tars":
return new TarsHandler(options)
case "human-relay":
return new HumanRelayHandler()
case "fake-ai":
234 changes: 234 additions & 0 deletions src/api/providers/__tests__/tars.spec.ts
@@ -0,0 +1,234 @@
// npx vitest run api/providers/__tests__/tars.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { TarsHandler } from "../tars"
import { ApiHandlerOptions } from "../../../shared/api"
import { Package } from "../../../shared/package"

const mockCreate = vitest.fn()

vitest.mock("openai", () => {
return {
default: vitest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}
})

vitest.mock("delay", () => ({ default: vitest.fn(() => Promise.resolve()) }))

vitest.mock("../fetchers/modelCache", () => ({
getModels: vitest.fn().mockImplementation(() => {
return Promise.resolve({
"gpt-4o": {
maxTokens: 16384,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 2.5,
outputPrice: 10.0,
cacheWritesPrice: 0,
cacheReadsPrice: 0,
description:
"OpenAI GPT-4o model routed through TARS for optimal performance and reliability. TARS automatically selects the best available provider.",
},
})
}),
}))

describe("TarsHandler", () => {
[Contributor review comment]

The test coverage looks good! Consider adding a few more edge case tests:

  • Handling when the TARS API returns models with missing fields
  • Testing the behavior when the API key is invalid
  • Testing model fetching with network errors

These additional tests would help ensure the robustness of the implementation (a sketch of one such test follows this file's diff).

const mockOptions: ApiHandlerOptions = {
tarsApiKey: "test-key",
tarsModelId: "gpt-4o",
}

beforeEach(() => vitest.clearAllMocks())

it("initializes with correct options", () => {
const handler = new TarsHandler(mockOptions)
expect(handler).toBeInstanceOf(TarsHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.router.tetrate.ai/v1",
apiKey: mockOptions.tarsApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
"User-Agent": `RooCode/${Package.version}`,
},
})
})

describe("fetchModel", () => {
it("returns correct model info when options are provided", async () => {
const handler = new TarsHandler(mockOptions)
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: mockOptions.tarsModelId,
info: {
maxTokens: 16384,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 2.5,
outputPrice: 10.0,
cacheWritesPrice: 0,
cacheReadsPrice: 0,
description:
"OpenAI GPT-4o model routed through TARS for optimal performance and reliability. TARS automatically selects the best available provider.",
},
})
})

it("returns default model info when options are not provided", async () => {
const handler = new TarsHandler({})
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: "claude-3-5-haiku-20241022",
info: {
maxTokens: 8192,
contextWindow: 200000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 0.8,
outputPrice: 4.0,
cacheWritesPrice: 1.0,
cacheReadsPrice: 0.08,
description:
"Claude 3.5 Haiku - Fast and cost-effective with excellent coding capabilities. Ideal for development tasks with 200k context window",
},
})
})
})

describe("createMessage", () => {
it("generates correct stream chunks", async () => {
const handler = new TarsHandler(mockOptions)

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: mockOptions.tarsModelId,
choices: [{ delta: { content: "test response" } }],
}
yield {
id: "test-id",
choices: [{ delta: { reasoning_content: "test reasoning" } }],
}
yield {
id: "test-id",
choices: [{ delta: {} }],
usage: {
prompt_tokens: 10,
completion_tokens: 20,
prompt_tokens_details: {
caching_tokens: 5,
cached_tokens: 2,
},
},
}
},
}

mockCreate.mockResolvedValue(mockStream)

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
const metadata = { taskId: "test-task-id", mode: "test-mode" }

const generator = handler.createMessage(systemPrompt, messages, metadata)
const chunks = []

for await (const chunk of generator) {
chunks.push(chunk)
}

// Verify stream chunks
expect(chunks).toHaveLength(3) // text, reasoning, and usage chunks
expect(chunks[0]).toEqual({ type: "text", text: "test response" })
expect(chunks[1]).toEqual({ type: "reasoning", text: "test reasoning" })
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheWriteTokens: 5,
cacheReadTokens: 2,
totalCost: expect.any(Number),
})

// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith({
max_tokens: 16384,
messages: [
{
role: "system",
content: "test system prompt",
},
{
role: "user",
content: "test message",
},
],
model: "gpt-4o",
stream: true,
stream_options: { include_usage: true },
temperature: 0,
})
})

it("handles API errors", async () => {
const handler = new TarsHandler(mockOptions)
const mockError = new Error("API Error")
mockCreate.mockRejectedValue(mockError)

const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow("API Error")
})
})

describe("completePrompt", () => {
it("returns correct response", async () => {
const handler = new TarsHandler(mockOptions)
const mockResponse = { choices: [{ message: { content: "test completion" } }] }

mockCreate.mockResolvedValue(mockResponse)

const result = await handler.completePrompt("test prompt")

expect(result).toBe("test completion")

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.tarsModelId,
max_tokens: 16384,
messages: [{ role: "system", content: "test prompt" }],
temperature: 0,
})
})

it("handles API errors", async () => {
const handler = new TarsHandler(mockOptions)
const mockError = new Error("API Error")
mockCreate.mockRejectedValue(mockError)

await expect(handler.completePrompt("test prompt")).rejects.toThrow("API Error")
})

it("handles unexpected errors", async () => {
const handler = new TarsHandler(mockOptions)
mockCreate.mockRejectedValue(new Error("Unexpected error"))

await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error")
})
})
})
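Following up on the reviewer suggestion above: a minimal sketch of one of the proposed edge-case tests, written to slot into the existing `describe("TarsHandler")` block of this spec file. It reuses the module-level `vitest.mock("../fetchers/modelCache")` setup and assumes `fetchModel()` passes through whatever the model cache returns for the configured id; the assertions are illustrative, not part of this PR.

```ts
// Sketch only — assumes it lives inside the existing describe("TarsHandler") block,
// so TarsHandler and the modelCache mock above are already in scope.
import { getModels } from "../fetchers/modelCache"

it("handles models returned with missing optional fields", async () => {
	// Override the default mock for a single call with a sparse model entry
	// (pricing and cache fields intentionally omitted).
	vitest.mocked(getModels).mockResolvedValueOnce({
		"gpt-4o": {
			maxTokens: 16384,
			contextWindow: 128000,
			supportsImages: false,
			supportsPromptCache: false,
		},
	})

	const handler = new TarsHandler({ tarsApiKey: "test-key", tarsModelId: "gpt-4o" })
	const result = await handler.fetchModel()

	// The handler should still resolve the configured model id and the fields it was given.
	expect(result.id).toBe("gpt-4o")
	expect(result.info.contextWindow).toBe(128000)
})
```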
5 changes: 5 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -12,6 +12,7 @@ import { fileExistsAtPath } from "../../../utils/fs"
import { getOpenRouterModels } from "./openrouter"
import { getVercelAiGatewayModels } from "./vercel-ai-gateway"
import { getRequestyModels } from "./requesty"
import { getTarsModels } from "./tars"
import { getGlamaModels } from "./glama"
import { getUnboundModels } from "./unbound"
import { getLiteLLMModels } from "./litellm"
@@ -62,6 +63,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
// Requesty models endpoint requires an API key for per-user custom policies
models = await getRequestyModels(options.baseUrl, options.apiKey)
break
case "tars":
// TARS models endpoint requires an API key
models = await getTarsModels(options.apiKey)
break
case "glama":
models = await getGlamaModels()
break
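The `./tars` fetcher imported above (`getTarsModels`) is not included in this excerpt. For orientation, a minimal sketch of what such a fetcher could look like, assuming TARS exposes an OpenAI-style `GET /models` endpoint at the base URL used in the handler tests (`https://api.router.tetrate.ai/v1`); the import path, response shape, and fallback values are assumptions, not the PR's actual implementation.

```ts
// Hypothetical sketch of src/api/providers/fetchers/tars.ts — not the code from this PR.
import type { ModelInfo } from "@roo-code/types" // assumed import path for packages/types

type ModelRecord = Record<string, ModelInfo>

export async function getTarsModels(apiKey?: string): Promise<ModelRecord> {
	const models: ModelRecord = {}

	// Mirrors the comment in modelCache.ts: the TARS models endpoint requires an API key.
	if (!apiKey) {
		return models
	}

	const response = await fetch("https://api.router.tetrate.ai/v1/models", {
		headers: { Authorization: `Bearer ${apiKey}` },
	})

	if (!response.ok) {
		throw new Error(`TARS /models request failed with status ${response.status}`)
	}

	// The response shape is assumed to be OpenAI-style: { data: [{ id, ... }] }.
	const body = (await response.json()) as { data?: Array<{ id: string } & Partial<ModelInfo>> }

	for (const model of body.data ?? []) {
		models[model.id] = {
			maxTokens: model.maxTokens ?? 8192,
			contextWindow: model.contextWindow ?? 200_000,
			supportsImages: model.supportsImages ?? false,
			supportsPromptCache: model.supportsPromptCache ?? false,
			inputPrice: model.inputPrice,
			outputPrice: model.outputPrice,
			description: model.description,
		}
	}

	return models
}
```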