8 changes: 8 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -38,6 +38,7 @@ export const providerNames = [
"sambanova",
"zai",
"fireworks",
"tars",
] as const

export const providerNamesSchema = z.enum(providerNames)
@@ -220,6 +221,11 @@ const requestySchema = baseProviderSettingsSchema.extend({
requestyModelId: z.string().optional(),
})

const tarsSchema = baseProviderSettingsSchema.extend({
tarsApiKey: z.string().optional(),
tarsModelId: z.string().optional(),
})

const humanRelaySchema = baseProviderSettingsSchema

const fakeAiSchema = baseProviderSettingsSchema.extend({
@@ -292,6 +298,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })),
tarsSchema.merge(z.object({ apiProvider: z.literal("tars") })),
humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })),
fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
@@ -327,6 +334,7 @@ export const providerSettingsSchema = z.object({
...moonshotSchema.shape,
...unboundSchema.shape,
...requestySchema.shape,
...tarsSchema.shape,
...humanRelaySchema.shape,
...fakeAiSchema.shape,
...xaiSchema.shape,
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -17,6 +17,7 @@ export * from "./openai.js"
export * from "./openrouter.js"
export * from "./requesty.js"
export * from "./sambanova.js"
export * from "./tars.js"
export * from "./unbound.js"
export * from "./vertex.js"
export * from "./vscode-llm.js"
17 changes: 17 additions & 0 deletions packages/types/src/providers/tars.ts
@@ -0,0 +1,17 @@
import type { ModelInfo } from "../model.js"

export const tarsDefaultModelId = "claude-3-5-haiku-20241022"

export const tarsDefaultModelInfo: ModelInfo = {
maxTokens: 8192,
contextWindow: 200000,
supportsImages: true,
supportsComputerUse: false,
supportsPromptCache: true,
inputPrice: 0.8,
outputPrice: 4.0,
cacheWritesPrice: 1.0,
cacheReadsPrice: 0.08,
description:
"Claude 3.5 Haiku - Fast and cost-effective with excellent coding capabilities. Ideal for development tasks with 200k context window",
}
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -23,6 +23,7 @@ import {
VsCodeLmHandler,
UnboundHandler,
RequestyHandler,
TarsHandler,
HumanRelayHandler,
FakeAIHandler,
XAIHandler,
@@ -108,6 +109,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new UnboundHandler(options)
case "requesty":
return new RequestyHandler(options)
case "tars":
return new TarsHandler(options)
case "human-relay":
return new HumanRelayHandler()
case "fake-ai":
234 changes: 234 additions & 0 deletions src/api/providers/__tests__/tars.spec.ts
@@ -0,0 +1,234 @@
// npx vitest run api/providers/__tests__/tars.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { TarsHandler } from "../tars"
import { ApiHandlerOptions } from "../../../shared/api"
import { Package } from "../../../shared/package"

const mockCreate = vitest.fn()

vitest.mock("openai", () => {
return {
default: vitest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}
})

vitest.mock("delay", () => ({ default: vitest.fn(() => Promise.resolve()) }))

vitest.mock("../fetchers/modelCache", () => ({
getModels: vitest.fn().mockImplementation(() => {
return Promise.resolve({
"gpt-4o": {
maxTokens: 16384,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 2.5,
outputPrice: 10.0,
cacheWritesPrice: 0,
cacheReadsPrice: 0,
description:
"OpenAI GPT-4o model routed through TARS for optimal performance and reliability. TARS automatically selects the best available provider.",
},
})
}),
}))

describe("TarsHandler", () => {
Contributor comment:

The test coverage looks good! Consider adding a few more edge case tests:

  • Handling when TARS API returns models with missing fields
  • Testing the behavior when API key is invalid
  • Testing model fetching with network errors

These additional tests would help ensure robustness of the implementation.

const mockOptions: ApiHandlerOptions = {
tarsApiKey: "test-key",
tarsModelId: "gpt-4o",
}

beforeEach(() => vitest.clearAllMocks())

it("initializes with correct options", () => {
const handler = new TarsHandler(mockOptions)
expect(handler).toBeInstanceOf(TarsHandler)

expect(OpenAI).toHaveBeenCalledWith({
baseURL: "https://api.router.tetrate.ai/v1",
apiKey: mockOptions.tarsApiKey,
defaultHeaders: {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
"X-Title": "Roo Code",
"User-Agent": `RooCode/${Package.version}`,
},
})
})

describe("fetchModel", () => {
it("returns correct model info when options are provided", async () => {
const handler = new TarsHandler(mockOptions)
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: mockOptions.tarsModelId,
info: {
maxTokens: 16384,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 2.5,
outputPrice: 10.0,
cacheWritesPrice: 0,
cacheReadsPrice: 0,
description:
"OpenAI GPT-4o model routed through TARS for optimal performance and reliability. TARS automatically selects the best available provider.",
},
})
})

it("returns default model info when options are not provided", async () => {
const handler = new TarsHandler({})
const result = await handler.fetchModel()

expect(result).toMatchObject({
id: "claude-3-5-haiku-20241022",
info: {
maxTokens: 8192,
contextWindow: 200000,
supportsImages: true,
supportsPromptCache: true,
supportsComputerUse: false,
inputPrice: 0.8,
outputPrice: 4.0,
cacheWritesPrice: 1.0,
cacheReadsPrice: 0.08,
description:
"Claude 3.5 Haiku - Fast and cost-effective with excellent coding capabilities. Ideal for development tasks with 200k context window",
},
})
})
})

describe("createMessage", () => {
it("generates correct stream chunks", async () => {
const handler = new TarsHandler(mockOptions)

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: mockOptions.tarsModelId,
choices: [{ delta: { content: "test response" } }],
}
yield {
id: "test-id",
choices: [{ delta: { reasoning_content: "test reasoning" } }],
}
yield {
id: "test-id",
choices: [{ delta: {} }],
usage: {
prompt_tokens: 10,
completion_tokens: 20,
prompt_tokens_details: {
caching_tokens: 5,
cached_tokens: 2,
},
},
}
},
}

mockCreate.mockResolvedValue(mockStream)

const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user" as const, content: "test message" }]
const metadata = { taskId: "test-task-id", mode: "test-mode" }

const generator = handler.createMessage(systemPrompt, messages, metadata)
const chunks = []

for await (const chunk of generator) {
chunks.push(chunk)
}

// Verify stream chunks
expect(chunks).toHaveLength(3) // text, reasoning, and usage chunks
expect(chunks[0]).toEqual({ type: "text", text: "test response" })
expect(chunks[1]).toEqual({ type: "reasoning", text: "test reasoning" })
expect(chunks[2]).toEqual({
type: "usage",
inputTokens: 10,
outputTokens: 20,
cacheWriteTokens: 5,
cacheReadTokens: 2,
totalCost: expect.any(Number),
})

// Verify OpenAI client was called with correct parameters
expect(mockCreate).toHaveBeenCalledWith({
max_tokens: 16384,
messages: [
{
role: "system",
content: "test system prompt",
},
{
role: "user",
content: "test message",
},
],
model: "gpt-4o",
stream: true,
stream_options: { include_usage: true },
temperature: 0,
})
})

it("handles API errors", async () => {
const handler = new TarsHandler(mockOptions)
const mockError = new Error("API Error")
mockCreate.mockRejectedValue(mockError)

const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow("API Error")
})
})

describe("completePrompt", () => {
it("returns correct response", async () => {
const handler = new TarsHandler(mockOptions)
const mockResponse = { choices: [{ message: { content: "test completion" } }] }

mockCreate.mockResolvedValue(mockResponse)

const result = await handler.completePrompt("test prompt")

expect(result).toBe("test completion")

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.tarsModelId,
max_tokens: 16384,
messages: [{ role: "system", content: "test prompt" }],
temperature: 0,
})
})

it("handles API errors", async () => {
const handler = new TarsHandler(mockOptions)
const mockError = new Error("API Error")
mockCreate.mockRejectedValue(mockError)

await expect(handler.completePrompt("test prompt")).rejects.toThrow("API Error")
})

it("handles unexpected errors", async () => {
const handler = new TarsHandler(mockOptions)
mockCreate.mockRejectedValue(new Error("Unexpected error"))

await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error")
})
})
})
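
Following up on the reviewer's first suggestion, here is a minimal sketch of a "missing fields" test that could be added inside the fetchModel describe block of tars.spec.ts. It assumes the already-mocked getModels from ../fetchers/modelCache is imported into the spec and re-stubbed per test with vitest.mocked, and it assumes fetchModel passes the cache entry through unchanged for the configured id, as the existing fetchModel tests suggest; the assertions would need adjusting to TarsHandler's actual behavior.

// Added to the existing imports in tars.spec.ts (the module is already mocked above).
import { getModels } from "../fetchers/modelCache"

it("tolerates models returned with missing optional fields", async () => {
	// Override the default mock for this test only: a cache entry that omits
	// pricing, description, and cache fields.
	vitest.mocked(getModels).mockResolvedValueOnce({
		"gpt-4o": {
			maxTokens: 16384,
			contextWindow: 128000,
			supportsImages: false,
			supportsPromptCache: false,
		},
	})

	const handler = new TarsHandler(mockOptions)
	const result = await handler.fetchModel()

	// Assumed pass-through behavior: the handler returns the cache entry for the configured id.
	expect(result.id).toBe("gpt-4o")
	expect(result.info.contextWindow).toBe(128000)
})
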
5 changes: 5 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -11,6 +11,7 @@ import { fileExistsAtPath } from "../../../utils/fs"

import { getOpenRouterModels } from "./openrouter"
import { getRequestyModels } from "./requesty"
import { getTarsModels } from "./tars"
import { getGlamaModels } from "./glama"
import { getUnboundModels } from "./unbound"
import { getLiteLLMModels } from "./litellm"
@@ -61,6 +62,10 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
// Requesty models endpoint requires an API key for per-user custom policies
models = await getRequestyModels(options.apiKey)
break
case "tars":
// TARS models endpoint requires an API key
models = await getTarsModels(options.apiKey)
break
case "glama":
models = await getGlamaModels()
break
54 changes: 54 additions & 0 deletions src/api/providers/fetchers/tars.ts
@@ -0,0 +1,54 @@
import axios from "axios"

import type { ModelInfo } from "@roo-code/types"

import { parseApiPrice } from "../../../shared/cost"

export async function getTarsModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
const models: Record<string, ModelInfo> = {}

try {
const headers: Record<string, string> = {}

if (apiKey) {
headers["Authorization"] = `Bearer ${apiKey}`
}

const url = "https://api.router.tetrate.ai/v1/models"
const response = await axios.get(url, { headers })
const rawModels = response.data.data

for (const rawModel of rawModels) {
// TARS supports reasoning for Claude and Gemini models similar to Requesty
const reasoningBudget =
rawModel.supports_reasoning &&
(rawModel.id.includes("claude") ||
rawModel.id.includes("coding/gemini-2.5") ||
rawModel.id.includes("vertex/gemini-2.5"))
const reasoningEffort =
rawModel.supports_reasoning &&
(rawModel.id.includes("openai") || rawModel.id.includes("google/gemini-2.5"))

const modelInfo: ModelInfo = {
maxTokens: rawModel.max_output_tokens || rawModel.max_tokens || 4096,
Contributor comment:

I noticed the handling here falls back from max_output_tokens to max_tokens. Is this intentional? Could we verify whether the TARS API actually returns both fields, or should we handle this more consistently?

For comparison, the Requesty fetcher reads only a single field, without the fallback. Would it make sense to align the approach?

contextWindow: rawModel.context_window || 128000,
supportsPromptCache: rawModel.supports_caching || rawModel.supports_prompt_cache || false,
supportsImages: rawModel.supports_vision || rawModel.supports_images || false,
supportsComputerUse: rawModel.supports_computer_use || false,
supportsReasoningBudget: reasoningBudget,
supportsReasoningEffort: reasoningEffort,
inputPrice: parseApiPrice(rawModel.input_price) || 0,
outputPrice: parseApiPrice(rawModel.output_price) || 0,
description: rawModel.description,
cacheWritesPrice: parseApiPrice(rawModel.caching_price || rawModel.cache_write_price) || 0,
cacheReadsPrice: parseApiPrice(rawModel.cached_price || rawModel.cache_read_price) || 0,
}

models[rawModel.id] = modelInfo
}
} catch (error) {
console.error(`Error fetching TARS models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
}

return models
}
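
The edge cases raised in both review comments (missing fields and network errors) can also be pinned down at the fetcher level, where the expected behavior is visible directly in getTarsModels above: absent fields fall back to 4096 max tokens, a 128k context window, and zero prices, and a failed request is caught and yields an empty record. Below is a hedged sketch of such a spec; the file path and the axios mock shape are assumptions, mirroring the mocking style used in tars.spec.ts.

// Hypothetical companion spec, e.g. src/api/providers/fetchers/__tests__/tars.spec.ts
import { getTarsModels } from "../tars"

const mockGet = vitest.fn()

vitest.mock("axios", () => ({
	default: { get: mockGet },
}))

describe("getTarsModels", () => {
	beforeEach(() => vitest.clearAllMocks())

	it("applies fallbacks when optional fields are missing", async () => {
		// A model entry with only an id: every optional field should get its fallback.
		mockGet.mockResolvedValueOnce({ data: { data: [{ id: "some/model" }] } })

		const models = await getTarsModels("test-key")

		expect(models["some/model"]).toMatchObject({
			maxTokens: 4096,
			contextWindow: 128000,
			supportsPromptCache: false,
			supportsImages: false,
			inputPrice: 0,
			outputPrice: 0,
		})
	})

	it("returns an empty record when the request fails", async () => {
		mockGet.mockRejectedValueOnce(new Error("network error"))

		// getTarsModels catches and logs the error rather than throwing.
		await expect(getTarsModels("test-key")).resolves.toEqual({})
	})
})
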
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -22,6 +22,7 @@ export { OpenAiHandler } from "./openai"
export { OpenRouterHandler } from "./openrouter"
export { RequestyHandler } from "./requesty"
export { SambaNovaHandler } from "./sambanova"
export { TarsHandler } from "./tars"
export { UnboundHandler } from "./unbound"
export { VertexHandler } from "./vertex"
export { VsCodeLmHandler } from "./vscode-lm"