1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
@@ -214,6 +214,7 @@ export const SECRET_STATE_KEYS = [
"chutesApiKey",
"litellmApiKey",
"deepInfraApiKey",
"cognimaApiKey",
"codeIndexOpenAiKey",
"codeIndexQdrantApiKey",
"codebaseIndexOpenAiCompatibleApiKey",
12 changes: 12 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -40,6 +40,7 @@ export const DEFAULT_CONSECUTIVE_MISTAKE_LIMIT = 3

export const dynamicProviders = [
"openrouter",
"cognima",
"vercel-ai-gateway",
"huggingface",
"litellm",
@@ -122,6 +123,7 @@ export const providerNames = [
"cerebras",
"chutes",
"claude-code",
"cognima",
"doubao",
"deepseek",
"featherless",
@@ -352,6 +354,11 @@ const groqSchema = apiModelIdProviderModelSchema.extend({
groqApiKey: z.string().optional(),
})

const cognimaSchema = baseProviderSettingsSchema.extend({
cognimaModelId: z.string().optional(),
cognimaApiKey: z.string().optional(),
})

const huggingFaceSchema = baseProviderSettingsSchema.extend({
huggingFaceApiKey: z.string().optional(),
huggingFaceModelId: z.string().optional(),
@@ -441,6 +448,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
cognimaSchema.merge(z.object({ apiProvider: z.literal("cognima") })),
huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
@@ -482,6 +490,7 @@ export const providerSettingsSchema = z.object({
...fakeAiSchema.shape,
...xaiSchema.shape,
...groqSchema.shape,
...cognimaSchema.shape,
...huggingFaceSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
@@ -528,6 +537,7 @@ export const modelIdKeys = [
"ioIntelligenceModelId",
"vercelAiGatewayModelId",
"deepInfraModelId",
"cognimaModelId",
] as const satisfies readonly (keyof ProviderSettings)[]

export type ModelIdKey = (typeof modelIdKeys)[number]
@@ -568,6 +578,7 @@ export const modelIdKeysByProvider: Record<TypicalProvider, ModelIdKey> = {
requesty: "requestyModelId",
xai: "apiModelId",
groq: "apiModelId",
cognima: "cognimaModelId",
chutes: "apiModelId",
litellm: "litellmModelId",
huggingface: "huggingFaceModelId",
@@ -661,6 +672,7 @@ export const MODELS_BY_PROVIDER: Record<
models: Object.keys(geminiModels),
},
groq: { id: "groq", label: "Groq", models: Object.keys(groqModels) },
"cognima": { id: "cognima", label: "Cognima", models: [] },
"io-intelligence": {
id: "io-intelligence",
label: "IO Intelligence",
3 changes: 3 additions & 0 deletions packages/types/src/providers/cognima.ts
@@ -0,0 +1,3 @@
export type CognimaModelId = string

export const cognimaDefaultModelId: CognimaModelId = "gpt-4o"
Review comment: Missing newline at end of file. Add a newline after the last line to follow standard conventions.

1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -2,6 +2,7 @@ export * from "./anthropic.js"
export * from "./bedrock.js"
export * from "./cerebras.js"
export * from "./chutes.js"
export * from "./cognima.js"
export * from "./claude-code.js"
export * from "./deepseek.js"
export * from "./doubao.js"
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -9,6 +9,7 @@ import {
AnthropicHandler,
AwsBedrockHandler,
CerebrasHandler,
CognimaHandler,
OpenRouterHandler,
VertexHandler,
AnthropicVertexHandler,
@@ -139,6 +140,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new XAIHandler(options)
case "groq":
return new GroqHandler(options)
case "cognima":
return new CognimaHandler(options)
case "deepinfra":
return new DeepInfraHandler(options)
case "huggingface":
92 changes: 92 additions & 0 deletions src/api/providers/cognima.ts
@@ -0,0 +1,92 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { type CognimaModelId, cognimaDefaultModelId } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStreamChunk } from "../transform/stream"
import { RouterProvider } from "./router-provider"
import { handleOpenAIError } from "./utils/openai-error-handler"

export class CognimaHandler extends RouterProvider {
private readonly providerName = "Cognima"

constructor(options: ApiHandlerOptions) {
super({
options,
name: "cognima",
baseURL: "https://cog2.cognima.com.br/openai/v1",
apiKey: options.cognimaApiKey,
modelId: options.cognimaModelId,
Review comment: The cognimaModelId property is referenced here but is not defined in the cognimaSchema in packages/types/src/provider-settings.ts. The schema only includes cognimaApiKey. You need to add cognimaModelId: z.string().optional() to the cognimaSchema (around line 358) to match the pattern used by other providers like Groq.

defaultModelId: cognimaDefaultModelId,
defaultModelInfo: {
maxTokens: 16384,
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 2.5,
outputPrice: 10,
supportsTemperature: true,
},
})
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): AsyncGenerator<ApiStreamChunk> {
const model = await this.fetchModel()
const modelId = model.id
const maxTokens = model.info.maxTokens
const temperature = 0 // Default temperature

// Convert Anthropic messages to OpenAI format
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]

const completionParams: OpenAI.Chat.ChatCompletionCreateParams = {
model: modelId,
...(maxTokens && maxTokens > 0 && { max_tokens: maxTokens }),
temperature,
messages: openAiMessages,
stream: true,
stream_options: { include_usage: true },
}

let stream
try {
stream = await this.client.chat.completions.create(completionParams)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

for await (const chunk of stream) {
// Handle OpenAI error responses
if ("error" in chunk) {
const error = chunk.error as { message?: string; code?: number }
console.error(`Cognima API Error: ${error?.code} - ${error?.message}`)
throw new Error(`Cognima API Error ${error?.code}: ${error?.message}`)
}

const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield { type: "text", text: delta.content }
}

if (chunk.usage) {
const usage = chunk.usage
yield {
type: "usage",
inputTokens: usage.prompt_tokens || 0,
outputTokens: usage.completion_tokens || 0,
totalCost: 0, // Cognima doesn't provide cost info in usage
}
}
}
}
}
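
For illustration, a minimal consumption sketch of this handler's streaming interface. The option values are placeholders, and constructing the handler directly (rather than through buildApiHandler) is assumed here for brevity:

```typescript
import { CognimaHandler } from "./cognima"

async function demo() {
	// Placeholder credentials and model id; real values come from provider settings.
	const handler = new CognimaHandler({
		cognimaApiKey: "YOUR_COGNIMA_KEY",
		cognimaModelId: "gpt-4o",
	})

	// createMessage is an async generator yielding text and usage chunks.
	for await (const chunk of handler.createMessage("You are a concise assistant.", [
		{ role: "user", content: "Say hello." },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		else if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
	}
}
```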
111 changes: 111 additions & 0 deletions src/api/providers/fetchers/cognima.ts
@@ -0,0 +1,111 @@
import axios from "axios"
import { z } from "zod"

import type { ModelInfo } from "@roo-code/types"

/**
* CognimaModel
*/

const cognimaModelSchema = z.object({
id: z.string(),
owned_by: z.string(),
object: z.string(),
created: z.number().optional(),
updated: z.number().optional(),
})

export type CognimaModel = z.infer<typeof cognimaModelSchema>

/**
* CognimaModelsResponse
*/

const cognimaModelsResponseSchema = z.object({
data: z.array(cognimaModelSchema),
object: z.string(),
})

type CognimaModelsResponse = z.infer<typeof cognimaModelsResponseSchema>

/**
* getCognimaModels
*/

export async function getCognimaModels(apiKey?: string, baseUrl?: string): Promise<Record<string, ModelInfo>> {
const models: Record<string, ModelInfo> = {}
const baseURL = baseUrl || "https://cog2.cognima.com.br/openai/v1"

try {
const response = await axios.get<CognimaModelsResponse>(`${baseURL}/models`, {
headers: {
Authorization: `Bearer ${apiKey || "not-provided"}`,
"Content-Type": "application/json",
},
})

const result = cognimaModelsResponseSchema.safeParse(response.data)
const data = result.success ? result.data.data : response.data.data

if (!result.success) {
console.error("Cognima models response is invalid", result.error.format())
}

for (const model of data) {
models[model.id] = parseCognimaModel(model)
}
} catch (error) {
console.error(
`Error fetching Cognima models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`,
)
}

return models
}

/**
* parseCognimaModel
*/

const parseCognimaModel = (model: CognimaModel): ModelInfo => {
// Provide basic ModelInfo with default values since Cognima API doesn't provide detailed pricing/info
// These defaults can be adjusted based on the actual models available
const modelInfo: ModelInfo = {
maxTokens: 4096, // Default value, can be adjusted per model if needed
contextWindow: 128000, // Default value, can be adjusted per model if needed
supportsImages: false, // Default to false, can be determined by model id patterns
supportsPromptCache: false, // Default to false
inputPrice: 0, // Default pricing, should be determined by actual API response or config
outputPrice: 0, // Default pricing, should be determined by actual API response or config
supportsTemperature: true,
}

// Add model-specific overrides based on ID patterns
if (model.id.includes("gpt-4o")) {
modelInfo.maxTokens = 16384
modelInfo.contextWindow = 128000
modelInfo.supportsImages = true
modelInfo.inputPrice = 2.5
modelInfo.outputPrice = 10
} else if (model.id.includes("gpt-4o-mini")) {
modelInfo.maxTokens = 16384
modelInfo.contextWindow = 128000
modelInfo.supportsImages = true
modelInfo.inputPrice = 0.15
modelInfo.outputPrice = 0.6
} else if (model.id.includes("claude-3-5-sonnet")) {
modelInfo.maxTokens = 8192
modelInfo.contextWindow = 200000
modelInfo.supportsImages = true
modelInfo.inputPrice = 3.0
modelInfo.outputPrice = 15.0
} else if (model.id.includes("llama-3.1-70b")) {
modelInfo.maxTokens = 4096
modelInfo.contextWindow = 128000
modelInfo.supportsImages = false
modelInfo.inputPrice = 0.52
modelInfo.outputPrice = 0.75
}

return modelInfo
}
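
A short usage sketch for the fetcher above (the environment variable name is hypothetical; without a key the function still sends a placeholder bearer token, and it returns an empty record on failure):

```typescript
import { getCognimaModels } from "./cognima"

async function listCognimaModels() {
	const models = await getCognimaModels(process.env.COGNIMA_API_KEY)
	for (const [id, info] of Object.entries(models)) {
		console.log(id, `context=${info.contextWindow}`, `maxTokens=${info.maxTokens}`)
	}
}
```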
4 changes: 4 additions & 0 deletions src/api/providers/fetchers/modelCache.ts
@@ -25,6 +25,7 @@ import { getIOIntelligenceModels } from "./io-intelligence"
import { getDeepInfraModels } from "./deepinfra"
import { getHuggingFaceModels } from "./huggingface"
import { getRooModels } from "./roo"
import { getCognimaModels } from "./cognima"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

@@ -67,6 +68,9 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
case "openrouter":
models = await getOpenRouterModels()
break
case "cognima":
models = await getCognimaModels(options.apiKey, options.baseUrl)
break
case "requesty":
// Requesty models endpoint requires an API key for per-user custom policies.
models = await getRequestyModels(options.baseUrl, options.apiKey)
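A hypothetical call through the cache layer wired up above; the provider id and option shape match the webview tests later in this diff:

```typescript
import { getModels } from "./modelCache"

async function loadCognimaModels() {
	// Served from the in-memory cache when fresh; otherwise falls through to getCognimaModels.
	const models = await getModels({ provider: "cognima", apiKey: process.env.COGNIMA_API_KEY })
	console.log(Object.keys(models))
}
```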
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -2,6 +2,7 @@ export { AnthropicVertexHandler } from "./anthropic-vertex"
export { AnthropicHandler } from "./anthropic"
export { AwsBedrockHandler } from "./bedrock"
export { CerebrasHandler } from "./cerebras"
export { CognimaHandler } from "./cognima"
export { ChutesHandler } from "./chutes"
export { ClaudeCodeHandler } from "./claude-code"
export { DeepSeekHandler } from "./deepseek"
6 changes: 5 additions & 1 deletion src/core/webview/__tests__/webviewMessageHandler.spec.ts
@@ -221,7 +221,7 @@ describe("webviewMessageHandler - requestRouterModels", () => {
expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound",apiKey: "unbound-key" })
Review comment: Missing space after comma. Should be { provider: "unbound", apiKey: "unbound-key" } for consistency with the rest of the codebase.

Suggested change
expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound",apiKey: "unbound-key" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" })

expect(mockGetModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" })
expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" })
expect(mockGetModels).toHaveBeenCalledWith(
@@ -230,6 +230,7 @@
baseUrl: expect.any(String),
}),
)
expect(mockGetModels).toHaveBeenCalledWith({ provider: "cognima", apiKey: undefined })
expect(mockGetModels).toHaveBeenCalledWith({
provider: "litellm",
apiKey: "litellm-key",
@@ -249,6 +250,7 @@
unbound: mockModels,
litellm: mockModels,
roo: mockModels,
cognima: {},
ollama: {},
lmstudio: {},
"vercel-ai-gateway": mockModels,
@@ -340,6 +342,7 @@
glama: mockModels,
unbound: mockModels,
roo: mockModels,
cognima: {},
litellm: {},
ollama: {},
lmstudio: {},
@@ -385,6 +388,7 @@
glama: mockModels,
unbound: {},
roo: mockModels,
cognima: {},
litellm: {},
ollama: {},
lmstudio: {},