diff --git a/src/api/providers/__tests__/roo.spec.ts b/src/api/providers/__tests__/roo.spec.ts
index cd1ab3330104..7555a49d498e 100644
--- a/src/api/providers/__tests__/roo.spec.ts
+++ b/src/api/providers/__tests__/roo.spec.ts
@@ -182,6 +182,21 @@ describe("RooHandler", () => {
 			handler = new RooHandler(mockOptions)
 		})
 
+		it("should update API key before making request", async () => {
+			// Set up a fresh token that will be returned when createMessage is called
+			const freshToken = "fresh-session-token"
+			mockGetSessionTokenFn.mockReturnValue(freshToken)
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			// Verify getSessionToken was called to get the fresh token
+			expect(mockGetSessionTokenFn).toHaveBeenCalled()
+		})
+
 		it("should handle streaming responses", async () => {
 			const stream = handler.createMessage(systemPrompt, messages)
 			const chunks: any[] = []
@@ -290,6 +305,25 @@ describe("RooHandler", () => {
 			})
 		})
 
+		it("should update API key before making request", async () => {
+			// Set up a fresh token that will be returned when completePrompt is called
+			const freshToken = "fresh-session-token"
+			mockGetSessionTokenFn.mockReturnValue(freshToken)
+
+			// Access the client's apiKey property to verify it gets updated
+			const clientApiKeyGetter = vitest.fn()
+			Object.defineProperty(handler["client"], "apiKey", {
+				get: clientApiKeyGetter,
+				set: vitest.fn(),
+				configurable: true,
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			// Verify getSessionToken was called to get the fresh token
+			expect(mockGetSessionTokenFn).toHaveBeenCalled()
+		})
+
 		it("should handle API errors", async () => {
 			mockCreate.mockRejectedValueOnce(new Error("API Error"))
 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts
index 3bd1bb65dc36..327796a1ffca 100644
--- a/src/api/providers/roo.ts
+++ b/src/api/providers/roo.ts
@@ -1,7 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
-import { AuthState, rooDefaultModelId, type ModelInfo } from "@roo-code/types"
+import { rooDefaultModelId } from "@roo-code/types"
 import { CloudService } from "@roo-code/cloud"
 
 import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
@@ -12,9 +12,8 @@
 import type { RooReasoningParams } from "../transform/reasoning"
 import { getRooReasoning } from "../transform/reasoning"
 import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { DEFAULT_HEADERS } from "./constants"
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
-import { getModels, flushModels, getModelsFromCache } from "../providers/fetchers/modelCache"
+import { getModels, getModelsFromCache } from "../providers/fetchers/modelCache"
 import { handleOpenAIError } from "./utils/openai-error-handler"
 
 // Extend OpenAI's CompletionUsage to include Roo specific fields
@@ -28,16 +27,16 @@
 type RooChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParamsStreaming & {
 	reasoning?: RooReasoningParams
 }
 
+function getSessionToken(): string {
+	const token = CloudService.hasInstance() ? CloudService.instance.authService?.getSessionToken() : undefined
+	return token ?? "unauthenticated"
+}
+
 export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
-	private authStateListener?: (state: { state: AuthState }) => void
 	private fetcherBaseURL: string
 
 	constructor(options: ApiHandlerOptions) {
-		let sessionToken: string | undefined = undefined
-
-		if (CloudService.hasInstance()) {
-			sessionToken = CloudService.instance.authService?.getSessionToken()
-		}
+		const sessionToken = getSessionToken()
 
 		let baseURL = process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
@@ -52,7 +51,7 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 			...options,
 			providerName: "Roo Code Cloud",
 			baseURL, // Already has /v1 suffix
-			apiKey: sessionToken || "unauthenticated", // Use a placeholder if no token.
+			apiKey: sessionToken,
 			defaultProviderModelId: rooDefaultModelId,
 			providerModels: {},
 			defaultTemperature: 0.7,
@@ -63,29 +62,6 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 		this.loadDynamicModels(this.fetcherBaseURL, sessionToken).catch((error) => {
 			console.error("[RooHandler] Failed to load dynamic models:", error)
 		})
-
-		if (CloudService.hasInstance()) {
-			const cloudService = CloudService.instance
-
-			this.authStateListener = (state: { state: AuthState }) => {
-				// Update OpenAI client with current auth token
-				// Note: Model cache flush/reload is handled by extension.ts authStateChangedHandler
-				const newToken = cloudService.authService?.getSessionToken()
-				this.client = new OpenAI({
-					baseURL: this.baseURL,
-					apiKey: newToken ?? "unauthenticated",
-					defaultHeaders: DEFAULT_HEADERS,
-				})
-			}
-
-			cloudService.on("auth-state-changed", this.authStateListener)
-		}
-	}
-
-	dispose() {
-		if (this.authStateListener && CloudService.hasInstance()) {
-			CloudService.instance.off("auth-state-changed", this.authStateListener)
-		}
 	}
 
 	protected override createStream(
@@ -127,6 +103,7 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 		}
 		try {
+			this.client.apiKey = getSessionToken()
 			return this.client.chat.completions.create(rooParams, requestOptions)
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
 		}
@@ -195,6 +172,11 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 			}
 		}
 	}
+	override async completePrompt(prompt: string): Promise<string> {
+		// Update API key before making request to ensure we use the latest session token
+		this.client.apiKey = getSessionToken()
+		return super.completePrompt(prompt)
+	}
 
 	private async loadDynamicModels(baseURL: string, apiKey?: string): Promise<void> {
 		try {