Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/api/providers/__tests__/roo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ describe("RooHandler", () => {
handler = new RooHandler(mockOptions)
})

it("should update API key before making request", async () => {
// Set up a fresh token that will be returned when createMessage is called
const freshToken = "fresh-session-token"
mockGetSessionTokenFn.mockReturnValue(freshToken)

const stream = handler.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
// Just consume
}

// Verify getSessionToken was called to get the fresh token
expect(mockGetSessionTokenFn).toHaveBeenCalled()
})

it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
Expand Down Expand Up @@ -290,6 +305,25 @@ describe("RooHandler", () => {
})
})

it("should update API key before making request", async () => {
// Set up a fresh token that will be returned when completePrompt is called
const freshToken = "fresh-session-token"
mockGetSessionTokenFn.mockReturnValue(freshToken)

// Access the client's apiKey property to verify it gets updated
const clientApiKeyGetter = vitest.fn()
Object.defineProperty(handler["client"], "apiKey", {
get: clientApiKeyGetter,
set: vitest.fn(),
configurable: true,
})

await handler.completePrompt("Test prompt")

// Verify getSessionToken was called to get the fresh token
expect(mockGetSessionTokenFn).toHaveBeenCalled()
})

it("should handle API errors", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
Expand Down
48 changes: 15 additions & 33 deletions src/api/providers/roo.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { AuthState, rooDefaultModelId, type ModelInfo } from "@roo-code/types"
import { rooDefaultModelId } from "@roo-code/types"
import { CloudService } from "@roo-code/cloud"

import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
Expand All @@ -12,9 +12,8 @@ import type { RooReasoningParams } from "../transform/reasoning"
import { getRooReasoning } from "../transform/reasoning"

import type { ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
import { getModels, flushModels, getModelsFromCache } from "../providers/fetchers/modelCache"
import { getModels, getModelsFromCache } from "../providers/fetchers/modelCache"
import { handleOpenAIError } from "./utils/openai-error-handler"

// Extend OpenAI's CompletionUsage to include Roo specific fields
Expand All @@ -28,16 +27,16 @@ type RooChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParamsStreaming &
reasoning?: RooReasoningParams
}

/**
 * Resolves the current Roo Code Cloud session token.
 *
 * @returns The session token from the CloudService auth service, or the
 * "unauthenticated" placeholder when no CloudService instance exists or
 * no token is available.
 */
function getSessionToken(): string {
	if (!CloudService.hasInstance()) {
		return "unauthenticated"
	}
	return CloudService.instance.authService?.getSessionToken() ?? "unauthenticated"
}

export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
private authStateListener?: (state: { state: AuthState }) => void
private fetcherBaseURL: string

constructor(options: ApiHandlerOptions) {
let sessionToken: string | undefined = undefined

if (CloudService.hasInstance()) {
sessionToken = CloudService.instance.authService?.getSessionToken()
}
const sessionToken = getSessionToken()

let baseURL = process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"

Expand All @@ -52,7 +51,7 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
...options,
providerName: "Roo Code Cloud",
baseURL, // Already has /v1 suffix
apiKey: sessionToken || "unauthenticated", // Use a placeholder if no token.
apiKey: sessionToken,
defaultProviderModelId: rooDefaultModelId,
providerModels: {},
defaultTemperature: 0.7,
Expand All @@ -63,29 +62,6 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
this.loadDynamicModels(this.fetcherBaseURL, sessionToken).catch((error) => {
console.error("[RooHandler] Failed to load dynamic models:", error)
})

if (CloudService.hasInstance()) {
const cloudService = CloudService.instance

this.authStateListener = (state: { state: AuthState }) => {
// Update OpenAI client with current auth token
// Note: Model cache flush/reload is handled by extension.ts authStateChangedHandler
const newToken = cloudService.authService?.getSessionToken()
this.client = new OpenAI({
baseURL: this.baseURL,
apiKey: newToken ?? "unauthenticated",
defaultHeaders: DEFAULT_HEADERS,
})
}

cloudService.on("auth-state-changed", this.authStateListener)
}
}

dispose() {
if (this.authStateListener && CloudService.hasInstance()) {
CloudService.instance.off("auth-state-changed", this.authStateListener)
}
}

protected override createStream(
Expand Down Expand Up @@ -127,6 +103,7 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
}

try {
this.client.apiKey = getSessionToken()
return this.client.chat.completions.create(rooParams, requestOptions)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
Expand Down Expand Up @@ -195,6 +172,11 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
}
}
}
override async completePrompt(prompt: string): Promise<string> {
// Update API key before making request to ensure we use the latest session token
this.client.apiKey = getSessionToken()
return super.completePrompt(prompt)
}

private async loadDynamicModels(baseURL: string, apiKey?: string): Promise<void> {
try {
Expand Down