diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index fef7d811a4f..c5ceaa3c4ba 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -47,6 +47,7 @@ export const providerNames = [ "zai", "fireworks", "io-intelligence", + "copilot", ] as const export const providerNamesSchema = z.enum(providerNames) @@ -288,6 +289,10 @@ const ioIntelligenceSchema = apiModelIdProviderModelSchema.extend({ ioIntelligenceApiKey: z.string().optional(), }) +const copilotSchema = baseProviderSettingsSchema.extend({ + copilotModelId: z.string().optional(), +}) + const defaultSchema = z.object({ apiProvider: z.undefined(), }) @@ -324,6 +329,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv zaiSchema.merge(z.object({ apiProvider: z.literal("zai") })), fireworksSchema.merge(z.object({ apiProvider: z.literal("fireworks") })), ioIntelligenceSchema.merge(z.object({ apiProvider: z.literal("io-intelligence") })), + copilotSchema.merge(z.object({ apiProvider: z.literal("copilot") })), defaultSchema, ]) @@ -360,6 +366,7 @@ export const providerSettingsSchema = z.object({ ...zaiSchema.shape, ...fireworksSchema.shape, ...ioIntelligenceSchema.shape, + ...copilotSchema.shape, ...codebaseIndexProviderSchema.shape, }) @@ -386,6 +393,7 @@ export const MODEL_ID_KEYS: Partial[] = [ "litellmModelId", "huggingFaceModelId", "ioIntelligenceModelId", + "copilotModelId", ] export const getModelId = (settings: ProviderSettings): string | undefined => { diff --git a/packages/types/src/providers/copilot.ts b/packages/types/src/providers/copilot.ts new file mode 100644 index 00000000000..d43373d3100 --- /dev/null +++ b/packages/types/src/providers/copilot.ts @@ -0,0 +1,16 @@ +export const copilotDefaultModelId = "gpt-4.1" + +export const GITHUB_CLIENT_ID = "Iv1.b507a08c87ecfe98" +export const GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code" +export const GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token" +export const GITHUB_API_KEY_URL = "https://api.github.com/copilot_internal/v2/token" +export const GITHUB_COPILOT_API_BASE = "https://api.githubcopilot.com/" + +export const COPILOT_DEFAULT_HEADER = { + accept: "application/json", + "content-type": "application/json", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + "user-agent": "GithubCopilot/1.155.0", + "accept-encoding": "gzip,deflate,br", +} diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index b7f1cd334e4..cd85456c791 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -25,3 +25,4 @@ export * from "./xai.js" export * from "./doubao.js" export * from "./zai.js" export * from "./fireworks.js" +export * from "./copilot.js" diff --git a/src/api/index.ts b/src/api/index.ts index c29c230b063..d52d541a2dc 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -36,6 +36,7 @@ import { DoubaoHandler, ZAiHandler, FireworksHandler, + CopilotHandler, } from "./providers" export interface SingleCompletionHandler { @@ -140,6 +141,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new FireworksHandler(options) case "io-intelligence": return new IOIntelligenceHandler(options) + case "copilot": + return new CopilotHandler(options) default: apiProvider satisfies "gemini-cli" | undefined return new AnthropicHandler(options) diff --git a/src/api/providers/__tests__/copilot.spec.ts b/src/api/providers/__tests__/copilot.spec.ts new file mode 100644 index 00000000000..94f552e774f --- /dev/null +++ b/src/api/providers/__tests__/copilot.spec.ts @@ -0,0 +1,501 @@ +import { vi, describe, it, expect, beforeEach } from "vitest" + +// Create mock functions at the top level +const mockCreate = vi.fn() + +// Mock the CopilotAuthenticator +const mockAuthenticator = { + getApiKey: vi.fn(), + isAuthenticated: vi.fn(), + clearAuth: vi.fn(), +} + +vi.mock("../fetchers/copilot", () => ({ + CopilotAuthenticator: { + getInstance: () => mockAuthenticator, + }, +})) + +// Mock getModels from modelCache +vi.mock("../fetchers/modelCache", () => ({ + getModels: vi.fn(), +})) +vi.mock("openai", () => { + return { + __esModule: true, + default: vi.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate.mockImplementation(async (options) => { + if (!options.stream) { + return { + id: "test-completion", + choices: [ + { + message: { role: "assistant", content: "Test Copilot response", refusal: null }, + finish_reason: "stop", + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + prompt_tokens_details: { + cache_miss_tokens: 8, + cached_tokens: 2, + }, + }, + } + } + + // Return async iterator for streaming + return { + [Symbol.asyncIterator]: async function* () { + yield { + choices: [ + { + delta: { content: "Test Copilot response" }, + index: 0, + }, + ], + usage: null, + } + yield { + choices: [ + { + delta: {}, + index: 0, + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + prompt_tokens_details: { + cache_miss_tokens: 8, + cached_tokens: 2, + }, + }, + } + }, + } + }), + }, + }, + })), + } +}) + +import OpenAI from "openai" +import type { Anthropic } from "@anthropic-ai/sdk" + +import { copilotDefaultModelId, GITHUB_COPILOT_API_BASE } from "@roo-code/types" +import type { ApiHandlerOptions } from "../../../shared/api" +import { getModels } from "../fetchers/modelCache" + +import { CopilotHandler } from "../copilot" + +const mockGetModels = getModels as any + +describe("CopilotHandler", () => { + let handler: CopilotHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + copilotModelId: "gpt-4", + apiModelId: "gpt-4", + openAiStreamingEnabled: true, + } + + // Mock successful authentication + mockAuthenticator.getApiKey.mockResolvedValue({ + apiKey: "test-api-key", + apiBase: GITHUB_COPILOT_API_BASE, + }) + + // Mock models + mockGetModels.mockResolvedValue({ + "gpt-4": { + maxTokens: 8192, + contextWindow: 128000, + supportsImages: false, + supportsPromptCache: true, + inputPrice: 0, + outputPrice: 0, + description: "GPT-4 via Copilot", + }, + }) + + handler = new CopilotHandler(mockOptions) + vi.clearAllMocks() + }) + + describe("constructor", () => { + it("should initialize with provided options", async () => { + expect(handler).toBeInstanceOf(CopilotHandler) + mockAuthenticator.getApiKey.mockResolvedValueOnce({ + apiKey: "new-api-key", + apiBase: GITHUB_COPILOT_API_BASE, + }) + // Access private method through any cast for testing + await (handler as any).ensureAuthenticated() + + expect(mockAuthenticator.getApiKey).toHaveBeenCalled() + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "new-api-key", + baseURL: GITHUB_COPILOT_API_BASE, + }), + ) + }) + + it("should set up CopilotAuthenticator", () => { + expect(mockAuthenticator).toBeDefined() + }) + }) + + describe("ensureAuthenticated", () => { + it("should authenticate and update client", async () => { + mockAuthenticator.getApiKey.mockResolvedValueOnce({ + apiKey: "new-api-key", + apiBase: "https://custom.copilot.api", + }) + + // Access private method through any cast for testing + await (handler as any).ensureAuthenticated() + + expect(mockAuthenticator.getApiKey).toHaveBeenCalled() + // The client should be updated with new credentials + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "new-api-key", + baseURL: "https://custom.copilot.api", + }), + ) + }) + + it("should handle authentication errors", async () => { + mockAuthenticator.getApiKey.mockRejectedValueOnce(new Error("Auth failed")) + + await expect((handler as any).ensureAuthenticated()).rejects.toThrow( + "Failed to authenticate with Copilot: Error: Auth failed", + ) + }) + }) + + describe("getModel", () => { + it("should return model info for valid model ID", async () => { + const model = await handler.fetchModel() + expect(model.id).toBe("gpt-4") + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(128000) + }) + + it("should return default model if model ID not provided", async () => { + const handlerWithoutModel = new CopilotHandler({ + ...mockOptions, + copilotModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(copilotDefaultModelId) + }) + + it("should return fallback info for unknown models", async () => { + mockGetModels.mockResolvedValueOnce({}) + const handlerWithUnknownModel = new CopilotHandler({ + ...mockOptions, + copilotModelId: "unknown-model", + }) + const model = handlerWithUnknownModel.getModel() + expect(model.id).toBe("unknown-model") + expect(model.info.description).toContain("Copilot Model (Fallback)") + }) + }) + + describe("determineInitiator", () => { + it("should return 'user' for task messages", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "Do something" }], + }, + ] + const initiator = (handler as any).determineInitiator(messages) + expect(initiator).toBe("user") + }) + + it("should return 'agent' for assistant messages", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [{ type: "text", text: "I can help with that." }], + }, + ] + const initiator = (handler as any).determineInitiator(messages) + expect(initiator).toBe("agent") + }) + + it("should return 'agent' for tool result messages", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "test-tool", + content: "Tool executed successfully", + }, + ], + }, + ] + const initiator = (handler as any).determineInitiator(messages) + expect(initiator).toBe("agent") + }) + + it("should return 'user' for empty messages array", () => { + const messages: Anthropic.Messages.MessageParam[] = [] + const initiator = (handler as any).determineInitiator(messages) + expect(initiator).toBe("user") + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful coding assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Write a function to reverse a string", + }, + ], + }, + ] + + beforeEach(() => { + // Reset authentication mock for each test + mockAuthenticator.getApiKey.mockResolvedValue({ + apiKey: "test-api-key", + apiBase: GITHUB_COPILOT_API_BASE, + }) + }) + + it("should handle streaming responses", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test Copilot response") + + // Verify authentication was called + expect(mockAuthenticator.getApiKey).toHaveBeenCalled() + }) + + it("should include X-Initiator header", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + headers: { + "X-Initiator": "user", + }, + }), + ) + }) + + it("should include usage information", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(5) + }) + + it("should include cache metrics in usage information", async () => { + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + const chunk = usageChunks[0] + expect(chunk.cacheWriteTokens).toBe(8) + expect(chunk.cacheReadTokens).toBe(2) + }) + + it("should handle non-streaming requests", async () => { + const nonStreamingHandler = new CopilotHandler({ + ...mockOptions, + openAiStreamingEnabled: false, + }) + + const stream = nonStreamingHandler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test Copilot response") + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks).toHaveLength(1) + }) + + it("should add max_tokens when includeMaxTokens is enabled", async () => { + const handlerWithMaxTokens = new CopilotHandler({ + ...mockOptions, + includeMaxTokens: true, + modelMaxTokens: 4096, + }) + + const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + max_completion_tokens: 4096, + }), + expect.any(Object), + ) + }) + + it("should handle authentication failures gracefully", async () => { + mockAuthenticator.getApiKey.mockRejectedValueOnce(new Error("Auth failed")) + + const stream = handler.createMessage(systemPrompt, messages) + + await expect(async () => { + for await (const chunk of stream) { + // Should throw before yielding any chunks + } + }).rejects.toThrow("Failed to authenticate with Copilot") + }) + }) + + describe("completePrompt", () => { + it("should complete a simple prompt", async () => { + const prompt = "Write a hello world function" + const result = await handler.completePrompt(prompt) + + expect(result).toBe("Test Copilot response") + expect(mockAuthenticator.getApiKey).toHaveBeenCalled() + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + messages: [{ role: "user", content: prompt }], + }), + ) + }) + + it("should handle authentication errors in completePrompt", async () => { + mockAuthenticator.getApiKey.mockRejectedValueOnce(new Error("Auth failed")) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow( + "Copilot completion error: Failed to authenticate with Copilot: Error: Auth failed", + ) + }) + + it("should handle API errors in completePrompt", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow("API Error") + }) + }) + + describe("authentication methods", () => { + it("should check if authenticated", async () => { + mockAuthenticator.isAuthenticated.mockResolvedValueOnce(true) + + const result = await handler.isAuthenticated() + expect(result).toBe(true) + expect(mockAuthenticator.isAuthenticated).toHaveBeenCalled() + }) + + it("should clear authentication", async () => { + await handler.clearAuth() + expect(mockAuthenticator.clearAuth).toHaveBeenCalled() + }) + }) + + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information", () => { + const usage = { + prompt_tokens: 100, + completion_tokens: 50, + total_tokens: 150, + prompt_tokens_details: { + cache_miss_tokens: 80, + cached_tokens: 20, + }, + } + + const result = (handler as any).processUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) + }) + + it("should handle missing usage data gracefully", () => { + const usage = { + prompt_tokens: 100, + completion_tokens: 50, + // No details + } + + const result = (handler as any).processUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() + }) + }) + + describe("fetchModel", () => { + it("should fetch models from cache and return current model", async () => { + const model = await handler.fetchModel() + + expect(mockGetModels).toHaveBeenCalledWith({ provider: "copilot" }) + expect(model.id).toBe("gpt-4") + expect(model.info).toBeDefined() + }) + + it("should handle empty models cache", async () => { + mockGetModels.mockResolvedValueOnce({}) + + const model = await handler.fetchModel() + + expect(model.id).toBe("gpt-4") + expect(model.info.description).toContain("Copilot Model (Fallback)") + }) + }) +}) diff --git a/src/api/providers/copilot.ts b/src/api/providers/copilot.ts new file mode 100644 index 00000000000..3ae0c933e80 --- /dev/null +++ b/src/api/providers/copilot.ts @@ -0,0 +1,261 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" + +import { + COPILOT_DEFAULT_HEADER, + copilotDefaultModelId, + GITHUB_COPILOT_API_BASE, + type ModelInfo, + openAiModelInfoSaneDefaults, +} from "@roo-code/types" + +import type { ApiHandlerOptions, ModelRecord } from "../../shared/api" + +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { getApiRequestTimeout } from "./utils/timeout-config" +import { CopilotAuthenticator } from "./fetchers/copilot" +import { getModels } from "./fetchers/modelCache" + +/** + * Copilot API handler that provides direct access to Copilot models + * using GitHub's OAuth Device Code Flow for authentication. + * + * This handler automatically handles authentication via device code flow + * and supports dynamic model discovery through the Copilot API. + */ +export class CopilotHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + private client: OpenAI + private authenticator: CopilotAuthenticator + private models: ModelRecord = {} + + constructor(options: ApiHandlerOptions) { + super() + this.options = options + this.authenticator = CopilotAuthenticator.getInstance() + + // Initialize with placeholder values - actual API key will be obtained dynamically + this.client = new OpenAI({ + baseURL: GITHUB_COPILOT_API_BASE, + apiKey: "placeholder", + defaultHeaders: COPILOT_DEFAULT_HEADER, + timeout: getApiRequestTimeout(), + }) + } + + /** + * Get or refresh the Copilot API key using device code flow + */ + private async ensureAuthenticated(): Promise { + try { + const { apiKey, apiBase } = await this.authenticator.getApiKey() + + // Update client with new API key and base URL + const baseURL = apiBase || GITHUB_COPILOT_API_BASE + this.client = new OpenAI({ + baseURL, + apiKey, + defaultHeaders: COPILOT_DEFAULT_HEADER, + timeout: getApiRequestTimeout(), + }) + } catch (error) { + throw new Error(`Failed to authenticate with Copilot: ${error}`) + } + } + + /** + * Determine the X-Initiator header based on message roles + */ + private determineInitiator(messages: Anthropic.Messages.MessageParam[]): string { + const isUserMessage = (text: string) => text.includes("") || text.includes("") + if (messages.length === 0) { + return "user" + } + const lastMessage = messages[messages.length - 1] + if (lastMessage.role === "assistant") { + return "agent" + } + if (typeof lastMessage === "string") { + return "user" + } + if (Array.isArray(lastMessage.content)) { + if (lastMessage.content.some((i) => i.type === "tool_result")) { + return "agent" + } + if (lastMessage.content.some((i) => i.type === "text")) { + let typeMode = "agent" + if (lastMessage.content.some((i) => i.type === "text" && isUserMessage(i.text))) { + typeMode = "user" + } + return typeMode + } + } + + return "user" + } + + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + // Ensure we have a valid API key + await this.ensureAuthenticated() + + const { id: modelId, info: modelInfo } = await this.fetchModel() + + // Convert Anthropic messages to OpenAI format + let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { + role: "system", + content: systemPrompt, + } + const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)] + + // Add X-Initiator header + const initiator = this.determineInitiator(messages) + const headers = { + "X-Initiator": initiator, + } + + if (this.options.openAiStreamingEnabled ?? true) { + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { + model: modelId, + temperature: this.options.modelTemperature ?? 0, + messages: convertedMessages, + stream: true as const, + stream_options: { include_usage: true }, + } + + // Add max_tokens if needed + this.addMaxTokensIfNeeded(requestOptions, modelInfo) + + const stream = await this.client.chat.completions.create(requestOptions, { + headers, + }) + + for await (const chunk of stream) { + const delta = chunk.choices?.[0]?.delta + if (delta?.content) { + yield { + type: "text", + text: delta.content, + } + } + + // Handle usage information + if (chunk.usage) { + yield this.processUsageMetrics(chunk.usage) + } + } + } else { + // Non-streaming implementation + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: modelId, + temperature: this.options.modelTemperature ?? 0, + messages: convertedMessages, + } + + this.addMaxTokensIfNeeded(requestOptions, modelInfo) + + const response = await this.client.chat.completions.create(requestOptions, { + headers, + }) + + yield { + type: "text", + text: response.choices[0]?.message.content || "", + } + yield this.processUsageMetrics(response.usage) + } + } + + /** + * Process usage metrics from OpenAI response + */ + private processUsageMetrics(usage: any): ApiStreamUsageChunk { + return { + type: "usage", + inputTokens: usage?.prompt_tokens || 0, + outputTokens: usage?.completion_tokens || 0, + cacheWriteTokens: usage?.prompt_tokens_details?.cache_miss_tokens, + cacheReadTokens: usage?.prompt_tokens_details?.cached_tokens, + } + } + + override getModel() { + const id = this.options.copilotModelId ?? copilotDefaultModelId + if (id in this.models) { + const info = this.models[id] + const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options }) + return { id, info, ...params } + } + return { + id, + info: { + ...openAiModelInfoSaneDefaults, + description: `Copilot Model (Fallback): ${id}`, + }, + } + } + + async completePrompt(prompt: string): Promise { + try { + await this.ensureAuthenticated() + const model = this.getModel() + + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: model.id, + messages: [{ role: "user", content: prompt }], + } + + // Add max_tokens if needed + this.addMaxTokensIfNeeded(requestOptions, model.info) + + const response = await this.client.chat.completions.create(requestOptions) + + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`Copilot completion error: ${error.message}`) + } + + throw error + } + } + + /** + * Add max_completion_tokens to request options if needed + */ + private addMaxTokensIfNeeded( + requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParams, + modelInfo: ModelInfo, + ): void { + if (this.options.includeMaxTokens) { + requestOptions.max_completion_tokens = this.options.modelMaxTokens || modelInfo.maxTokens + } + } + + /** + * Check if the user is authenticated + */ + async isAuthenticated(): Promise { + return this.authenticator.isAuthenticated() + } + + /** + * Clear authentication data + */ + async clearAuth(): Promise { + return this.authenticator.clearAuth() + } + + public async fetchModel() { + this.models = await getModels({ provider: "copilot" }) + return this.getModel() + } +} diff --git a/src/api/providers/fetchers/__tests__/copilot.test.ts b/src/api/providers/fetchers/__tests__/copilot.test.ts new file mode 100644 index 00000000000..82298ebf87d --- /dev/null +++ b/src/api/providers/fetchers/__tests__/copilot.test.ts @@ -0,0 +1,190 @@ +// src/api/providers/fetchers/__tests__/copilot.test.ts +import { describe, it, expect, beforeEach, vi } from "vitest" +import axios from "axios" +import * as fs from "fs" +import { CopilotAuthenticator, getCopilotModels } from "../copilot" + +vi.mock("axios") +vi.mock("fs", () => ({ + promises: { + readFile: vi.fn(), + writeFile: vi.fn(), + mkdir: vi.fn(), + unlink: vi.fn(), + }, +})) +vi.mock("path", () => ({ + join: (...args: string[]) => args.join("/"), +})) +vi.mock("os", () => ({ + homedir: () => "/home/test", +})) + +const mockApiKey = "mock-api-key" +const mockAccessToken = "mock-access-token" +const mockCopilotToken = { + token: mockApiKey, + expires_at: Math.floor(Date.now() / 1000) + 3600, + endpoints: { api: "https://copilot.api" }, +} +const mockStoredTokens = { + access_token: mockAccessToken, + api_key: mockApiKey, + api_key_expires_at: Math.floor(Date.now() / 1000) + 3600, + api_base: "https://copilot.api", +} + +describe("CopilotAuthenticator", () => { + let authenticator: CopilotAuthenticator + + beforeEach(() => { + authenticator = CopilotAuthenticator.getInstance() + vi.clearAllMocks() + }) + + it("should return valid apiKey from stored tokens", async () => { + ;(fs.promises.readFile as any).mockResolvedValue(JSON.stringify(mockStoredTokens)) + expect(await authenticator.getApiKey()).toEqual({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + }) + + it("should refresh apiKey if expired", async () => { + ;(fs.promises.readFile as any).mockResolvedValue( + JSON.stringify({ ...mockStoredTokens, api_key_expires_at: Math.floor(Date.now() / 1000) - 10 }), + ) + ;(axios.get as any).mockResolvedValue({ data: mockCopilotToken }) + ;(fs.promises.writeFile as any).mockResolvedValue(undefined) + expect(await authenticator.getApiKey()).toEqual({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + }) + + it("should start device code flow if no token", async () => { + ;(fs.promises.readFile as any).mockRejectedValue(new Error("not found")) + ;(axios.post as any).mockResolvedValueOnce({ + data: { + device_code: "dev-code", + user_code: "user-code", + verification_uri: "https://verify", + expires_in: 600, + interval: 1, + }, + }) + ;(axios.post as any).mockResolvedValueOnce({ + data: { access_token: mockAccessToken }, + }) + ;(axios.get as any).mockResolvedValue({ data: mockCopilotToken }) + ;(fs.promises.writeFile as any).mockResolvedValue(undefined) + expect(await authenticator.getApiKey()).toEqual({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + }) + + it("should handle pollForAccessToken timeout", async () => { + ;(fs.promises.readFile as any).mockRejectedValue(new Error("not found")) + ;(axios.post as any).mockResolvedValueOnce({ + data: { + device_code: "dev-code", + user_code: "user-code", + verification_uri: "https://verify", + expires_in: 600, + interval: 0, // 立即执行 + }, + }) + ;(axios.post as any).mockResolvedValue({ data: { error: "authorization_pending" } }) + + // mock setTimeout 立即执行 + vi.stubGlobal("setTimeout", (fn: () => void, _ms: number) => { + fn() + }) + + authenticator.setAuthTimeoutCallback(vi.fn()) + await expect(authenticator.getApiKey()).rejects.toThrow("Authentication timed out") + + vi.unstubAllGlobals() + }) + + it("should clear authentication data", async () => { + ;(fs.promises.unlink as any).mockResolvedValue(undefined) + await expect(authenticator.clearAuth()).resolves.toBeUndefined() + }) + + it("should return isAuthenticated true if access_token exists", async () => { + ;(fs.promises.readFile as any).mockResolvedValue(JSON.stringify(mockStoredTokens)) + expect(await authenticator.isAuthenticated()).toBe(true) + }) + + it("should return isAuthenticated false if no access_token", async () => { + ;(fs.promises.readFile as any).mockResolvedValue(JSON.stringify({})) + expect(await authenticator.isAuthenticated()).toBe(false) + }) +}) + +describe("getCopilotModels", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("should fetch models and return ModelRecord", async () => { + const authenticator = CopilotAuthenticator.getInstance() + vi.spyOn(authenticator, "getApiKey").mockResolvedValue({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + data: [ + { + id: "model-1", + name: "Model One", + model_picker_enabled: true, + capabilities: { + limits: { max_output_tokens: 2048, max_context_window_tokens: 4096 }, + supports: { max_thinking_budget: 100 }, + }, + }, + ], + }), + }) + const result = await getCopilotModels() + expect(result["model-1"]).toMatchObject({ + maxTokens: 2048, + maxThinkingTokens: 100, + contextWindow: 4096, + description: "Model One", + supportsImages: false, + supportsComputerUse: false, + supportsPromptCache: true, + supportsVerbosity: false, + supportsReasoningBudget: false, + requiredReasoningBudget: false, + supportsReasoningEffort: false, + supportedParameters: ["reasoning"], + }) + }) + + it("should throw error if fetch fails", async () => { + const authenticator = CopilotAuthenticator.getInstance() + vi.spyOn(authenticator, "getApiKey").mockResolvedValue({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + global.fetch = vi.fn().mockResolvedValue({ ok: false, statusText: "Bad Request" }) + await expect(getCopilotModels()).rejects.toThrow("Failed to fetch Copilot models: Bad Request") + }) + + it("should throw error if fetch throws", async () => { + const authenticator = CopilotAuthenticator.getInstance() + vi.spyOn(authenticator, "getApiKey").mockResolvedValue({ + apiKey: mockApiKey, + apiBase: "https://copilot.api", + }) + global.fetch = vi.fn().mockRejectedValue(new Error("Network error")) + await expect(getCopilotModels()).rejects.toThrow("Failed to fetch Copilot models: Network error") + }) +}) diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 21f5ce8bff9..3ba5d479078 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -25,6 +25,7 @@ vi.mock("../requesty") vi.mock("../glama") vi.mock("../unbound") vi.mock("../io-intelligence") +vi.mock("../copilot") // Then imports import type { Mock } from "vitest" @@ -35,6 +36,7 @@ import { getRequestyModels } from "../requesty" import { getGlamaModels } from "../glama" import { getUnboundModels } from "../unbound" import { getIOIntelligenceModels } from "../io-intelligence" +import { getCopilotModels } from "../copilot" const mockGetLiteLLMModels = getLiteLLMModels as Mock const mockGetOpenRouterModels = getOpenRouterModels as Mock @@ -42,6 +44,7 @@ const mockGetRequestyModels = getRequestyModels as Mock const mockGetUnboundModels = getUnboundModels as Mock const mockGetIOIntelligenceModels = getIOIntelligenceModels as Mock +const mockGetCopilotModels = getCopilotModels as Mock const DUMMY_REQUESTY_KEY = "requesty-key-for-testing" const DUMMY_UNBOUND_KEY = "unbound-key-for-testing" @@ -158,6 +161,23 @@ describe("getModels with new GetModelsOptions", () => { expect(result).toEqual(mockModels) }) + it("calls getCopilotModels for copilot provider", async () => { + const mockModels = { + "copilot/model": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: true, + description: "Copilot model", + }, + } + mockGetCopilotModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "copilot" }) + + expect(mockGetCopilotModels).toHaveBeenCalled() + expect(result).toEqual(mockModels) + }) + it("handles errors and re-throws them", async () => { const expectedError = new Error("LiteLLM connection failed") mockGetLiteLLMModels.mockRejectedValue(expectedError) diff --git a/src/api/providers/fetchers/copilot.ts b/src/api/providers/fetchers/copilot.ts new file mode 100644 index 00000000000..f9036038615 --- /dev/null +++ b/src/api/providers/fetchers/copilot.ts @@ -0,0 +1,387 @@ +import axios from "axios" +import { promises as fs } from "fs" +import { join } from "path" +import { homedir } from "os" +import { ModelRecord } from "../../../shared/api" +import { + GITHUB_ACCESS_TOKEN_URL, + GITHUB_API_KEY_URL, + GITHUB_CLIENT_ID, + GITHUB_COPILOT_API_BASE, + GITHUB_DEVICE_CODE_URL, +} from "@roo-code/types" + +interface DeviceCodeResponse { + device_code: string + user_code: string + verification_uri: string + expires_in: number + interval: number +} + +interface AccessTokenResponse { + access_token?: string + error?: string + error_description?: string +} + +interface CopilotTokenResponse { + token: string + expires_at: number + refresh_in?: number + endpoints?: { + api?: string + } +} + +interface StoredTokenData { + access_token: string + api_key?: string + api_key_expires_at?: number + api_base?: string +} + +interface DeviceCodeInfo { + device_code: string + user_code: string + verification_uri: string + expires_in: number + interval: number +} + +export class CopilotAuthenticator { + private tokenDir: string + private tokenFile: string + private static instance: CopilotAuthenticator | null = null + private deviceCodeCallback?: (info: DeviceCodeInfo) => void + private authTimeoutCallback?: (error: string) => void + + private constructor() { + // Store tokens in user's home directory + this.tokenDir = join(homedir(), ".roo-code", "copilot") + this.tokenFile = join(this.tokenDir, "tokens.json") + } + + public static getInstance() { + if (CopilotAuthenticator.instance === null) { + CopilotAuthenticator.instance = new CopilotAuthenticator() + } + return CopilotAuthenticator.instance + } + + /** + * Set callback for device code information + */ + setDeviceCodeCallback(callback: (info: DeviceCodeInfo) => void) { + this.deviceCodeCallback = callback + } + + /** + * Set callback for authentication timeout + */ + setAuthTimeoutCallback(callback: (error: string) => void) { + this.authTimeoutCallback = callback + } + + /** + * Get a valid API key for Copilot + */ + async getApiKey(): Promise<{ apiKey: string; apiBase?: string }> { + try { + const stored = await this.loadStoredTokens() + + // Check if we have a valid API key + if (stored.api_key && stored.api_key_expires_at) { + const now = Math.floor(Date.now() / 1000) + if (stored.api_key_expires_at > now + 60) { + // 60 second buffer + return { + apiKey: stored.api_key, + apiBase: stored.api_base, + } + } + } + + // If we have an access token, try to refresh the API key + if (stored.access_token) { + try { + const copilotToken = await this.refreshApiKey(stored.access_token) + await this.saveTokens({ + access_token: stored.access_token, + api_key: copilotToken.token, + api_key_expires_at: copilotToken.expires_at, + api_base: copilotToken.endpoints?.api, + }) + return { + apiKey: copilotToken.token, + apiBase: copilotToken.endpoints?.api, + } + } catch (error) { + console.warn("Failed to refresh API key, starting new authentication:", error) + // Fall through to device code flow + } + } + + // Start device code flow + const accessToken = await this.authenticateWithDeviceCode() + const copilotToken = await this.refreshApiKey(accessToken) + + await this.saveTokens({ + access_token: accessToken, + api_key: copilotToken.token, + api_key_expires_at: copilotToken.expires_at, + api_base: copilotToken.endpoints?.api, + }) + + return { + apiKey: copilotToken.token, + apiBase: copilotToken.endpoints?.api, + } + } catch (error) { + throw new Error(`Failed to authenticate with Copilot: ${error}`) + } + } + + /** + * Start device code authentication flow + */ + private async authenticateWithDeviceCode(): Promise { + // Step 1: Get device code + const deviceResponse = await axios.post( + GITHUB_DEVICE_CODE_URL, + { + client_id: GITHUB_CLIENT_ID, + scope: "read:user", + }, + { + headers: { + Accept: "application/json", + "Content-Type": "application/json", + "User-Agent": "GitHubCopilotChat/0.26.7", + }, + }, + ) + + const deviceData = deviceResponse.data + + // Step 2: Show user code to user via callback + if (this.deviceCodeCallback) { + this.deviceCodeCallback({ + device_code: deviceData.device_code, + user_code: deviceData.user_code, + verification_uri: deviceData.verification_uri, + expires_in: deviceData.expires_in, + interval: deviceData.interval || 5, + }) + } + + // Step 3: Poll for access token + return this.pollForAccessToken(deviceData.device_code, deviceData.interval || 5) + } + + /** + * Poll GitHub for access token after user authorization + */ + private async pollForAccessToken(deviceCode: string, interval: number): Promise { + const maxAttempts = 60 // 5 minutes maximum + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + await new Promise((resolve) => setTimeout(resolve, interval * 1000)) + + try { + const response = await axios.post( + GITHUB_ACCESS_TOKEN_URL, + { + client_id: GITHUB_CLIENT_ID, + device_code: deviceCode, + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + }, + { + headers: { + Accept: "application/json", + "Content-Type": "application/json", + "User-Agent": "GitHubCopilotChat/0.26.7", + }, + }, + ) + + const data = response.data + + if (data.access_token) { + console.log("✅ Authentication successful!") + return data.access_token + } + + if (data.error === "authorization_pending") { + continue // Keep polling + } + + if (data.error === "slow_down") { + interval = Math.min(interval * 2, 10) // Increase interval + continue + } + + if (data.error) { + const errorMsg = `GitHub OAuth error: ${data.error} - ${data.error_description}` + if (this.authTimeoutCallback) { + this.authTimeoutCallback(errorMsg) + } + throw new Error(errorMsg) + } + } catch (error) { + if (axios.isAxiosError(error) && error.response?.status === 400) { + // Continue polling on 400 errors (authorization_pending) + continue + } + if (this.authTimeoutCallback) { + this.authTimeoutCallback(error instanceof Error ? error.message : "Authentication failed") + } + throw error + } + } + + const timeoutError = "Authentication timed out. Please try again." + if (this.authTimeoutCallback) { + this.authTimeoutCallback(timeoutError) + } + throw new Error(timeoutError) + } + + /** + * Exchange access token for Copilot API key + */ + private async refreshApiKey(accessToken: string): Promise { + const response = await axios.get(GITHUB_API_KEY_URL, { + headers: { + Accept: "application/json", + Authorization: `Bearer ${accessToken}`, + "User-Agent": "GitHubCopilotChat/0.26.7", + "Editor-Version": "vscode/1.85.1", + "Editor-Plugin-Version": "copilot-chat/0.26.7", + }, + }) + + return response.data + } + + /** + * Load stored tokens from file + */ + private async loadStoredTokens(): Promise> { + try { + await this.ensureTokenDir() + const data = await fs.readFile(this.tokenFile, "utf-8") + return JSON.parse(data) + } catch (error) { + return {} + } + } + + /** + * Save tokens to file + */ + private async saveTokens(tokens: StoredTokenData): Promise { + await this.ensureTokenDir() + await fs.writeFile(this.tokenFile, JSON.stringify(tokens, null, 2)) + } + + /** + * Ensure token directory exists + */ + private async ensureTokenDir(): Promise { + try { + await fs.mkdir(this.tokenDir, { recursive: true }) + } catch (error) { + // Directory might already exist + } + } + + /** + * Clear stored authentication data + */ + async clearAuth(): Promise { + try { + await fs.unlink(this.tokenFile) + console.log("🗑️ Cleared Copilot authentication data") + } catch (error) { + // File might not exist + } + } + + /** + * Check if user is authenticated + */ + async isAuthenticated(): Promise { + try { + const stored = await this.loadStoredTokens() + return !!stored.access_token + } catch (error) { + return false + } + } +} + +/** + * Get available Copilot models using device code authentication + */ +export async function getCopilotModels(): Promise { + try { + const authenticator = CopilotAuthenticator.getInstance() + const { apiKey, apiBase } = await authenticator.getApiKey() + + const baseURL = apiBase || GITHUB_COPILOT_API_BASE + const modelsUrl = `${baseURL.replace(/\/$/, "")}/models` + + const response = await fetch(modelsUrl, { + headers: { + Authorization: `Bearer ${apiKey}`, + Accept: "application/json", + "User-Agent": "GithubCopilot/1.155.0", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + }, + }) + + if (!response.ok) { + console.warn("Failed to fetch Copilot models:", response.statusText) + throw new Error(`Failed to fetch Copilot models: ${response.statusText}`) + } + + const data = await response.json() + const result = {} as ModelRecord + for (const model of data.data) { + if (model.model_picker_enabled !== true) { + continue + } + result[model.id] = { + maxTokens: model?.capabilities?.limits?.max_output_tokens, + maxThinkingTokens: model?.capabilities?.supports?.max_thinking_budget, + contextWindow: model?.capabilities?.limits?.max_context_window_tokens, + // supportsImages: !!model?.capabilities?.supports?.vision, + supportsImages: false, + supportsComputerUse: false, + supportsPromptCache: true, + supportsVerbosity: false, + // supportsReasoningBudget: !!model?.capabilities?.supports?.max_thinking_budget, + supportsReasoningBudget: false, + requiredReasoningBudget: false, + supportsReasoningEffort: false, + supportedParameters: model?.capabilities?.supports?.max_thinking_budget ? ["reasoning"] : [], + inputPrice: 0, + outputPrice: 0, + cacheWritesPrice: 0, + cacheReadsPrice: 0, + description: model.name, + reasoningEffort: undefined, + minTokensPerCachePoint: undefined, + maxCachePoints: undefined, + cachableFields: undefined, + tiers: undefined, + } + } + return result + } catch (error) { + console.error("Failed to fetch Copilot models:", error) + throw new Error(`Failed to fetch Copilot models: ${error instanceof Error ? error.message : String(error)}`) + } +} diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index a21e75ded93..407ecdcaf49 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -18,6 +18,7 @@ import { GetModelsOptions } from "../../../shared/api" import { getOllamaModels } from "./ollama" import { getLMStudioModels } from "./lmstudio" import { getIOIntelligenceModels } from "./io-intelligence" +import { getCopilotModels } from "./copilot" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { @@ -81,6 +82,9 @@ export const getModels = async (options: GetModelsOptions): Promise case "io-intelligence": models = await getIOIntelligenceModels(options.apiKey) break + case "copilot": + models = await getCopilotModels() + break default: { // Ensures router is exhaustively checked if RouterName is a strict union const exhaustiveCheck: never = provider diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts index 736da82d514..c245ab8b7ce 100644 --- a/src/api/providers/index.ts +++ b/src/api/providers/index.ts @@ -29,3 +29,4 @@ export { VsCodeLmHandler } from "./vscode-lm" export { XAIHandler } from "./xai" export { ZAiHandler } from "./zai" export { FireworksHandler } from "./fireworks" +export { CopilotHandler } from "./copilot" diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 29df1e4a2eb..92774be1479 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -2674,7 +2674,7 @@ describe("ClineProvider - Router Models", () => { apiKey: "litellm-key", baseUrl: "http://localhost:4000", }) - + expect(getModels).toHaveBeenCalledWith({ provider: "copilot" }) // Verify response was sent expect(mockPostMessage).toHaveBeenCalledWith({ type: "routerModels", @@ -2686,6 +2686,7 @@ describe("ClineProvider - Router Models", () => { litellm: mockModels, ollama: {}, lmstudio: {}, + copilot: mockModels, }, }) }) @@ -2717,6 +2718,7 @@ describe("ClineProvider - Router Models", () => { .mockResolvedValueOnce(mockModels) // glama success .mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail + .mockResolvedValueOnce(mockModels) // copilot success await messageHandler({ type: "requestRouterModels" }) @@ -2731,6 +2733,7 @@ describe("ClineProvider - Router Models", () => { ollama: {}, lmstudio: {}, litellm: {}, + copilot: mockModels, }, }) @@ -2841,6 +2844,7 @@ describe("ClineProvider - Router Models", () => { litellm: {}, ollama: {}, lmstudio: {}, + copilot: mockModels, }, }) }) diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 7ba7128bb95..afd3ada7406 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -2,13 +2,17 @@ import type { Mock } from "vitest" // Mock dependencies - must come before imports vi.mock("../../../api/providers/fetchers/modelCache") +vi.mock("../../../api/providers/fetchers/copilot") import { webviewMessageHandler } from "../webviewMessageHandler" import type { ClineProvider } from "../ClineProvider" import { getModels } from "../../../api/providers/fetchers/modelCache" +import { getCopilotModels, CopilotAuthenticator } from "../../../api/providers/fetchers/copilot" import type { ModelRecord } from "../../../shared/api" const mockGetModels = getModels as Mock +const mockGetCopilotModels = getCopilotModels as Mock +const mockCopilotAuthenticator = CopilotAuthenticator as any // Mock ClineProvider const mockClineProvider = { @@ -183,6 +187,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { apiKey: "litellm-key", baseUrl: "http://localhost:4000", }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "copilot" }) // Verify response was sent expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ @@ -195,6 +200,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { litellm: mockModels, ollama: {}, lmstudio: {}, + copilot: mockModels, }, }) }) @@ -282,6 +288,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { litellm: {}, ollama: {}, lmstudio: {}, + copilot: mockModels, }, }) }) @@ -303,6 +310,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { .mockResolvedValueOnce(mockModels) // glama .mockRejectedValueOnce(new Error("Unbound API error")) // unbound .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm + .mockResolvedValueOnce(mockModels) // copilot success await webviewMessageHandler(mockClineProvider, { type: "requestRouterModels", @@ -319,6 +327,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { litellm: {}, ollama: {}, lmstudio: {}, + copilot: mockModels, }, }) @@ -529,6 +538,276 @@ describe("webviewMessageHandler - deleteCustomMode", () => { }) }) +describe("webviewMessageHandler - requestCopilotModels", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("successfully fetches Copilot models", async () => { + const mockCopilotModels = { + "gpt-4o": { + maxTokens: 4096, + contextWindow: 128000, + supportsPromptCache: false, + description: "GPT-4 Omni model", + }, + "gpt-3.5-turbo": { + maxTokens: 4096, + contextWindow: 16385, + supportsPromptCache: false, + description: "GPT-3.5 Turbo model", + }, + } + + // Mock getCopilotModels to return mock models + mockGetCopilotModels.mockResolvedValue(mockCopilotModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestCopilotModels", + }) + + expect(mockGetCopilotModels).toHaveBeenCalledTimes(1) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotModels", + copilotModels: mockCopilotModels, + }) + }) + + it("handles errors when fetching Copilot models", async () => { + mockGetCopilotModels.mockRejectedValue(new Error("Authentication failed")) + + // Spy on console.error + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + await webviewMessageHandler(mockClineProvider, { + type: "requestCopilotModels", + }) + + expect(mockGetCopilotModels).toHaveBeenCalledTimes(1) + expect(consoleSpy).toHaveBeenCalledWith("Failed to fetch Copilot models:", expect.any(Error)) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotModels", + copilotModels: {}, + }) + + consoleSpy.mockRestore() + }) +}) + +describe("webviewMessageHandler - authenticateCopilot", () => { + let mockAuthenticator: any + + beforeEach(() => { + vi.clearAllMocks() + + // Mock CopilotAuthenticator + mockAuthenticator = { + setDeviceCodeCallback: vi.fn(), + setAuthTimeoutCallback: vi.fn(), + getApiKey: vi.fn(), + } + + vi.mocked(mockCopilotAuthenticator.getInstance).mockReturnValue(mockAuthenticator) + }) + + it("successfully starts device code authentication flow", async () => { + mockAuthenticator.getApiKey.mockResolvedValue("test-api-key") + + await webviewMessageHandler(mockClineProvider, { + type: "authenticateCopilot", + }) + + expect(mockAuthenticator.setDeviceCodeCallback).toHaveBeenCalledWith(expect.any(Function)) + expect(mockAuthenticator.setAuthTimeoutCallback).toHaveBeenCalledWith(expect.any(Function)) + expect(mockAuthenticator.getApiKey).toHaveBeenCalledTimes(1) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + }) + + it("handles authentication errors", async () => { + const error = new Error("Authentication timeout") + mockAuthenticator.getApiKey.mockRejectedValue(error) + + // Spy on console.error + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + await webviewMessageHandler(mockClineProvider, { + type: "authenticateCopilot", + }) + + expect(consoleSpy).toHaveBeenCalledWith("Failed to authenticate with Copilot:", error) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthError", + error: "Authentication timeout", + }) + + consoleSpy.mockRestore() + }) + + it("handles device code callback", async () => { + const mockDeviceInfo = { + user_code: "ABC123", + verification_uri: "https://github.com/login/device", + expires_in: 900, + } + + mockAuthenticator.getApiKey.mockResolvedValue("test-api-key") + + await webviewMessageHandler(mockClineProvider, { + type: "authenticateCopilot", + }) + + // Get the callback function that was set + const deviceCodeCallback = mockAuthenticator.setDeviceCodeCallback.mock.calls[0][0] + + // Call the callback with mock device info + deviceCodeCallback(mockDeviceInfo) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotDeviceCode", + copilotDeviceCode: { + user_code: "ABC123", + verification_uri: "https://github.com/login/device", + expires_in: 900, + }, + }) + }) + + it("handles auth timeout callback", async () => { + const timeoutError = "Device code expired" + mockAuthenticator.getApiKey.mockResolvedValue("test-api-key") + + await webviewMessageHandler(mockClineProvider, { + type: "authenticateCopilot", + }) + + // Get the callback function that was set + const authTimeoutCallback = mockAuthenticator.setAuthTimeoutCallback.mock.calls[0][0] + + // Call the callback with timeout error + authTimeoutCallback(timeoutError) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthError", + error: timeoutError, + }) + }) +}) + +describe("webviewMessageHandler - clearCopilotAuth", () => { + let mockAuthenticator: any + + beforeEach(() => { + vi.clearAllMocks() + + // Mock CopilotAuthenticator + mockAuthenticator = { + clearAuth: vi.fn(), + } + + vi.mocked(mockCopilotAuthenticator.getInstance).mockReturnValue(mockAuthenticator) + }) + + it("successfully clears Copilot authentication", async () => { + mockAuthenticator.clearAuth.mockResolvedValue(undefined) + + await webviewMessageHandler(mockClineProvider, { + type: "clearCopilotAuth", + }) + + expect(mockAuthenticator.clearAuth).toHaveBeenCalledTimes(1) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + }) + + it("handles errors when clearing authentication", async () => { + const error = new Error("Failed to clear token") + mockAuthenticator.clearAuth.mockRejectedValue(error) + + // Spy on console.error + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + await webviewMessageHandler(mockClineProvider, { + type: "clearCopilotAuth", + }) + + expect(consoleSpy).toHaveBeenCalledWith("Failed to clear Copilot authentication:", error) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthError", + error: "Failed to clear token", + }) + + consoleSpy.mockRestore() + }) +}) + +describe("webviewMessageHandler - checkCopilotAuth", () => { + let mockAuthenticator: any + + beforeEach(() => { + vi.clearAllMocks() + + // Mock CopilotAuthenticator + mockAuthenticator = { + isAuthenticated: vi.fn(), + } + + vi.mocked(mockCopilotAuthenticator.getInstance).mockReturnValue(mockAuthenticator) + }) + + it("returns true when user is authenticated", async () => { + mockAuthenticator.isAuthenticated.mockResolvedValue(true) + + await webviewMessageHandler(mockClineProvider, { + type: "checkCopilotAuth", + }) + + expect(mockAuthenticator.isAuthenticated).toHaveBeenCalledTimes(1) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + }) + + it("returns false when user is not authenticated", async () => { + mockAuthenticator.isAuthenticated.mockResolvedValue(false) + + await webviewMessageHandler(mockClineProvider, { + type: "checkCopilotAuth", + }) + + expect(mockAuthenticator.isAuthenticated).toHaveBeenCalledTimes(1) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + }) + + it("handles errors when checking authentication status", async () => { + const error = new Error("Network error") + mockAuthenticator.isAuthenticated.mockRejectedValue(error) + + // Spy on console.error + const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + await webviewMessageHandler(mockClineProvider, { + type: "checkCopilotAuth", + }) + + expect(consoleSpy).toHaveBeenCalledWith("Failed to check Copilot authentication:", error) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + + consoleSpy.mockRestore() + }) +}) + describe("webviewMessageHandler - message dialog preferences", () => { beforeEach(() => { vi.clearAllMocks() diff --git a/src/core/webview/message-handle/copilot/authenticateCopilot.ts b/src/core/webview/message-handle/copilot/authenticateCopilot.ts new file mode 100644 index 00000000000..3bc3394ffe2 --- /dev/null +++ b/src/core/webview/message-handle/copilot/authenticateCopilot.ts @@ -0,0 +1,47 @@ +import { MessageHandlerStrategy, MessageHandlerContext } from "../types" +import { CopilotAuthenticator } from "../../../../api/providers/fetchers/copilot" + +/** + * Strategy for handling authenticateCopilot message + */ +export class AuthenticateCopilotStrategy implements MessageHandlerStrategy { + async handle(context: MessageHandlerContext): Promise { + const { provider } = context + + // Start device code authentication for Copilot + try { + const authenticator = CopilotAuthenticator.getInstance() + + // Set up callbacks + authenticator.setDeviceCodeCallback((deviceInfo) => { + provider.postMessageToWebview({ + type: "copilotDeviceCode", + copilotDeviceCode: { + user_code: deviceInfo.user_code, + verification_uri: deviceInfo.verification_uri, + expires_in: deviceInfo.expires_in, + }, + }) + }) + + authenticator.setAuthTimeoutCallback((error) => { + provider.postMessageToWebview({ + type: "copilotAuthError", + error: error, + }) + }) + + await authenticator.getApiKey() // This will trigger the device code flow + provider.postMessageToWebview({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + } catch (error) { + console.error("Failed to authenticate with Copilot:", error) + provider.postMessageToWebview({ + type: "copilotAuthError", + error: error instanceof Error ? error.message : "Authentication failed", + }) + } + } +} diff --git a/src/core/webview/message-handle/copilot/checkCopilotAuth.ts b/src/core/webview/message-handle/copilot/checkCopilotAuth.ts new file mode 100644 index 00000000000..e6db489a023 --- /dev/null +++ b/src/core/webview/message-handle/copilot/checkCopilotAuth.ts @@ -0,0 +1,27 @@ +import { MessageHandlerStrategy, MessageHandlerContext } from "../types" +import { CopilotAuthenticator } from "../../../../api/providers/fetchers/copilot" + +/** + * Strategy for handling checkCopilotAuth message + */ +export class CheckCopilotAuthStrategy implements MessageHandlerStrategy { + async handle(context: MessageHandlerContext): Promise { + const { provider } = context + + try { + const authenticator = CopilotAuthenticator.getInstance() + const isAuthenticated = await authenticator.isAuthenticated() + + provider.postMessageToWebview({ + type: "copilotAuthStatus", + copilotAuthenticated: isAuthenticated, + }) + } catch (error) { + console.error("Failed to check Copilot authentication:", error) + provider.postMessageToWebview({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + } + } +} diff --git a/src/core/webview/message-handle/copilot/clearCopilotAuth.ts b/src/core/webview/message-handle/copilot/clearCopilotAuth.ts new file mode 100644 index 00000000000..66c14edd59a --- /dev/null +++ b/src/core/webview/message-handle/copilot/clearCopilotAuth.ts @@ -0,0 +1,26 @@ +import { MessageHandlerStrategy, MessageHandlerContext } from "../types" +import { CopilotAuthenticator } from "../../../../api/providers/fetchers/copilot" + +/** + * Strategy for handling clearCopilotAuth message + */ +export class ClearCopilotAuthStrategy implements MessageHandlerStrategy { + async handle(context: MessageHandlerContext): Promise { + const { provider } = context + + try { + const authenticator = CopilotAuthenticator.getInstance() + await authenticator.clearAuth() + provider.postMessageToWebview({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + } catch (error) { + console.error("Failed to clear Copilot authentication:", error) + provider.postMessageToWebview({ + type: "copilotAuthError", + error: error instanceof Error ? error.message : "Failed to clear authentication", + }) + } + } +} diff --git a/src/core/webview/message-handle/copilot/index.ts b/src/core/webview/message-handle/copilot/index.ts new file mode 100644 index 00000000000..a786800e2a2 --- /dev/null +++ b/src/core/webview/message-handle/copilot/index.ts @@ -0,0 +1,29 @@ +/** + * Copilot message handling module + * + * This module implements the Strategy pattern for handling Copilot-related messages. + * Each message type has its own dedicated handler that implements the MessageHandler interface. + * + * Supported message types: + * - requestCopilotModels: Fetches available Copilot models + * - authenticateCopilot: Initiates Copilot authentication flow + * - clearCopilotAuth: Clears stored Copilot authentication + * - checkCopilotAuth: Checks current Copilot authentication status + */ + +import { RequestCopilotModelsStrategy } from "./requestCopilotModels" +import { AuthenticateCopilotStrategy } from "./authenticateCopilot" +import { ClearCopilotAuthStrategy } from "./clearCopilotAuth" +import { CheckCopilotAuthStrategy } from "./checkCopilotAuth" +import { MessageHandlerRegistry } from "../types" + +/** + * Register all Copilot-related message handler strategies + */ +export function registerCopilotStrategies(messageHandlerRegistry: MessageHandlerRegistry): void { + // Register each strategy with its corresponding message type + messageHandlerRegistry.registerStrategy("requestCopilotModels", new RequestCopilotModelsStrategy()) + messageHandlerRegistry.registerStrategy("authenticateCopilot", new AuthenticateCopilotStrategy()) + messageHandlerRegistry.registerStrategy("clearCopilotAuth", new ClearCopilotAuthStrategy()) + messageHandlerRegistry.registerStrategy("checkCopilotAuth", new CheckCopilotAuthStrategy()) +} diff --git a/src/core/webview/message-handle/copilot/requestCopilotModels.ts b/src/core/webview/message-handle/copilot/requestCopilotModels.ts new file mode 100644 index 00000000000..9eed2fc9eda --- /dev/null +++ b/src/core/webview/message-handle/copilot/requestCopilotModels.ts @@ -0,0 +1,26 @@ +import { MessageHandlerStrategy, MessageHandlerContext } from "../types" +import { getCopilotModels } from "../../../../api/providers/fetchers/copilot" +import { ModelRecord } from "../../../../shared/api" + +/** + * Strategy for handling requestCopilotModels message + */ +export class RequestCopilotModelsStrategy implements MessageHandlerStrategy { + async handle(context: MessageHandlerContext): Promise { + const { provider } = context + + try { + const copilotModels = await getCopilotModels() + provider.postMessageToWebview({ + type: "copilotModels", + copilotModels, + }) + } catch (error) { + console.error("Failed to fetch Copilot models:", error) + provider.postMessageToWebview({ + type: "copilotModels", + copilotModels: {}, + }) + } + } +} diff --git a/src/core/webview/message-handle/index.ts b/src/core/webview/message-handle/index.ts new file mode 100644 index 00000000000..b35c59aaa5f --- /dev/null +++ b/src/core/webview/message-handle/index.ts @@ -0,0 +1,2 @@ +export * from "./types" +export * from "./registry" diff --git a/src/core/webview/message-handle/registry.ts b/src/core/webview/message-handle/registry.ts new file mode 100644 index 00000000000..4ff0593d6cd --- /dev/null +++ b/src/core/webview/message-handle/registry.ts @@ -0,0 +1,42 @@ +import { registerCopilotStrategies } from "./copilot" +import { MessageHandlerStrategy, MessageHandlerRegistry, MessageHandlerContext } from "./types" + +/** + * Central registry for message handler strategies + * Implements strategy pattern with key-based registration + */ +export class DefaultMessageHandlerRegistry implements MessageHandlerRegistry { + private static instance: DefaultMessageHandlerRegistry | null = null + private strategies: Map = new Map() + + private constructor() {} + + public static getInstance() { + if (DefaultMessageHandlerRegistry.instance === null) { + DefaultMessageHandlerRegistry.instance = new DefaultMessageHandlerRegistry() + registerCopilotStrategies(DefaultMessageHandlerRegistry.instance) + } + return DefaultMessageHandlerRegistry.instance + } + + /** + * Registers a strategy for handling a specific message type + */ + registerStrategy(messageType: string, strategy: MessageHandlerStrategy): void { + this.strategies.set(messageType, strategy) + } + + /** + * Gets a strategy for the given message type + */ + getStrategy(messageType: string): MessageHandlerStrategy | null { + return this.strategies.get(messageType) || null + } + + /** + * Gets all registered message types + */ + getSupportedTypes(): string[] { + return Array.from(this.strategies.keys()) + } +} diff --git a/src/core/webview/message-handle/types.ts b/src/core/webview/message-handle/types.ts new file mode 100644 index 00000000000..e2a8181b1e9 --- /dev/null +++ b/src/core/webview/message-handle/types.ts @@ -0,0 +1,51 @@ +import { ClineProvider } from "../ClineProvider" +import { WebviewMessage } from "../../../shared/WebviewMessage" +import { MarketplaceManager } from "../../../services/marketplace" + +/** + * Context provided to message handlers + */ +export interface MessageHandlerContext { + /** The ClineProvider instance */ + provider: ClineProvider + /** The webview message to handle */ + message: WebviewMessage + /** Optional marketplace manager */ + marketplaceManager?: MarketplaceManager +} + +/** + * Strategy interface for handling specific message types + */ +export interface MessageHandlerStrategy { + /** + * Handles a specific webview message type + * @param context The message handler context + */ + handle(context: MessageHandlerContext): Promise +} + +/** + * Registry for message handler strategies + */ +export interface MessageHandlerRegistry { + /** + * Registers a strategy for handling a specific message type + * @param messageType The type of message to handle + * @param strategy The strategy to register + */ + registerStrategy(messageType: string, strategy: MessageHandlerStrategy): void + + /** + * Gets a strategy for the given message type + * @param messageType The type of message to handle + * @returns MessageHandlerStrategy instance or null if not supported + */ + getStrategy(messageType: string): MessageHandlerStrategy | null + + /** + * Gets all registered message types + * @returns Array of supported message type strings + */ + getSupportedTypes(): string[] +} diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 7e25ae14dcd..2d6e32756cc 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -4,15 +4,8 @@ import * as os from "os" import * as fs from "fs/promises" import pWaitFor from "p-wait-for" import * as vscode from "vscode" -import * as yaml from "yaml" - -import { - type Language, - type ProviderSettings, - type GlobalState, - type ClineMessage, - TelemetryEventName, -} from "@roo-code/types" + +import { type Language, type GlobalState, type ClineMessage, TelemetryEventName } from "@roo-code/types" import { CloudService } from "@roo-code/cloud" import { TelemetryService } from "@roo-code/telemetry" import { type ApiMessage } from "../task-persistence/apiMessages" @@ -21,7 +14,6 @@ import { ClineProvider } from "./ClineProvider" import { changeLanguage, t } from "../../i18n" import { Package } from "../../shared/package" import { RouterName, toRouterName, ModelRecord } from "../../shared/api" -import { supportPrompt } from "../../shared/support-prompt" import { MessageEnhancer } from "./messageEnhancer" import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage" @@ -29,7 +21,6 @@ import { checkExistKey } from "../../shared/checkExistApiConfig" import { experimentDefault } from "../../shared/experiments" import { Terminal } from "../../integrations/terminal/Terminal" import { openFile } from "../../integrations/misc/open-file" -import { CodeIndexManager } from "../../services/code-index/manager" import { openImage, saveImage } from "../../integrations/misc/image-handler" import { selectImages } from "../../integrations/misc/process-images" import { getTheme } from "../../integrations/theme/getTheme" @@ -44,12 +35,12 @@ import { getVsCodeLmModels } from "../../api/providers/vscode-lm" import { openMention } from "../mentions" import { TelemetrySetting } from "../../shared/TelemetrySetting" import { getWorkspacePath } from "../../utils/path" -import { ensureSettingsDirectoryExists } from "../../utils/globalContext" import { Mode, defaultModeSlug } from "../../shared/modes" import { getModels, flushModels } from "../../api/providers/fetchers/modelCache" import { GetModelsOptions } from "../../shared/api" import { generateSystemPrompt } from "./generateSystemPrompt" import { getCommand } from "../../utils/commands" +import { DefaultMessageHandlerRegistry } from "./message-handle" const ALLOWED_VSCODE_SETTINGS = new Set(["terminal.integrated.inheritEnv"]) @@ -65,7 +56,7 @@ export const webviewMessageHandler = async ( const getGlobalState = (key: K) => provider.contextProxy.getValue(key) const updateGlobalState = async (key: K, value: GlobalState[K]) => await provider.contextProxy.setValue(key, value) - + const messageHandler = DefaultMessageHandlerRegistry.getInstance() /** * Shared utility to find message indices based on timestamp */ @@ -527,6 +518,7 @@ export const webviewMessageHandler = async ( litellm: {}, ollama: {}, lmstudio: {}, + copilot: {}, } const safeGetModels = async (options: GetModelsOptions): Promise => { @@ -568,6 +560,7 @@ export const webviewMessageHandler = async ( options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, }) } + modelFetchPromises.push({ key: "copilot", options: { provider: "copilot" } }) const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { @@ -2611,5 +2604,17 @@ export const webviewMessageHandler = async ( } break } + default: { + // Try to handle the message using the strategy pattern + const handler = await messageHandler.getStrategy(message.type) + if (handler) { + await handler.handle({ provider, message, marketplaceManager }) + break + } + + // Message type not recognized + console.warn(`Unhandled message type: ${message.type}`) + break + } } } diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index ebdc137432b..b6b498d7196 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -75,6 +75,10 @@ export interface ExtensionMessage { | "ollamaModels" | "lmStudioModels" | "vsCodeLmModels" + | "copilotModels" + | "copilotAuthStatus" + | "copilotAuthError" + | "copilotDeviceCode" | "huggingFaceModels" | "vsCodeLmApiAvailable" | "updatePrompt" @@ -147,6 +151,13 @@ export interface ExtensionMessage { ollamaModels?: string[] lmStudioModels?: ModelRecord vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] + copilotModels?: ModelRecord + copilotAuthenticated?: boolean + copilotDeviceCode?: { + user_code: string + verification_uri: string + expires_in: number + } huggingFaceModels?: Array<{ id: string object: string diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index d59ccd556c8..4aa790e2226 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -67,6 +67,10 @@ export interface WebviewMessage { | "requestOllamaModels" | "requestLmStudioModels" | "requestVsCodeLmModels" + | "requestCopilotModels" + | "authenticateCopilot" + | "clearCopilotAuth" + | "checkCopilotAuth" | "requestHuggingFaceModels" | "openImage" | "saveImage" diff --git a/src/shared/api.ts b/src/shared/api.ts index f1bf7dbaea4..bdbd3cc98e6 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -27,6 +27,7 @@ const routerNames = [ "ollama", "lmstudio", "io-intelligence", + "copilot", ] as const export type RouterName = (typeof routerNames)[number] @@ -141,3 +142,4 @@ export type GetModelsOptions = | { provider: "ollama"; baseUrl?: string } | { provider: "lmstudio"; baseUrl?: string } | { provider: "io-intelligence"; apiKey: string } + | { provider: "copilot" } diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index dcdf072a11c..c17a5a75ad9 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -32,6 +32,7 @@ import { mainlandZAiDefaultModelId, fireworksDefaultModelId, ioIntelligenceDefaultModelId, + copilotDefaultModelId, } from "@roo-code/types" import { vscode } from "@src/utils/vscode" @@ -83,6 +84,7 @@ import { Unbound, Vertex, VSCodeLM, + Copilot, XAI, ZAi, Fireworks, @@ -326,6 +328,7 @@ const ApiOptions = ({ openai: { field: "openAiModelId" }, ollama: { field: "ollamaModelId" }, lmstudio: { field: "lmStudioModelId" }, + copilot: { field: "copilotModelId", default: copilotDefaultModelId }, } const config = PROVIDER_MODEL_CONFIG[value] @@ -510,6 +513,15 @@ const ApiOptions = ({ )} + {selectedProvider === "copilot" && ( + + )} + {selectedProvider === "ollama" && ( )} diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 2160d4b3c55..6d0455ff68e 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -35,6 +35,7 @@ type ModelIdKey = keyof Pick< | "unboundModelId" | "requestyModelId" | "openAiModelId" + | "copilotModelId" | "litellmModelId" | "ioIntelligenceModelId" > diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts index 5882b1bf4a8..449db353a09 100644 --- a/webview-ui/src/components/settings/constants.ts +++ b/webview-ui/src/components/settings/constants.ts @@ -70,4 +70,5 @@ export const PROVIDERS = [ { value: "zai", label: "Z AI" }, { value: "fireworks", label: "Fireworks AI" }, { value: "io-intelligence", label: "IO Intelligence" }, + { value: "copilot", label: "Copilot" }, ].sort((a, b) => a.label.localeCompare(b.label)) diff --git a/webview-ui/src/components/settings/providers/Copilot.tsx b/webview-ui/src/components/settings/providers/Copilot.tsx new file mode 100644 index 00000000000..eca1fb109c9 --- /dev/null +++ b/webview-ui/src/components/settings/providers/Copilot.tsx @@ -0,0 +1,235 @@ +import { useState, useCallback, useEffect } from "react" +import { useEvent } from "react-use" +import { VSCodeButton } from "@vscode/webview-ui-toolkit/react" + +import { type ProviderSettings, type ModelInfo, copilotDefaultModelId } from "@roo-code/types" +import { ExtensionMessage } from "@roo/ExtensionMessage" + +import { useAppTranslation } from "@src/i18n/TranslationContext" +import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" +import { vscode } from "@src/utils/vscode" +import { ModelPicker } from "../ModelPicker" +import { OrganizationAllowList } from "@roo/cloud" + +type CopilotProps = { + apiConfiguration: ProviderSettings + setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void + organizationAllowList: OrganizationAllowList + modelValidationError?: string +} + +export const Copilot = ({ + apiConfiguration, + setApiConfigurationField, + modelValidationError, + organizationAllowList, +}: CopilotProps) => { + const { t } = useAppTranslation() + + const [copilotModels, setCopilotModels] = useState | null>(null) + const [isAuthenticated, setIsAuthenticated] = useState(false) + const [isAuthenticating, setIsAuthenticating] = useState(false) + const [deviceCodeInfo, setDeviceCodeInfo] = useState<{ + user_code: string + verification_uri: string + expires_in: number + } | null>(null) + const [authError, setAuthError] = useState(null) + + const handleAuthenticateClick = useCallback(() => { + setIsAuthenticating(true) + setAuthError(null) + setDeviceCodeInfo(null) + // Send message to extension to start device code authentication + vscode.postMessage({ + type: "authenticateCopilot", + }) + }, []) + + const handleClearAuthClick = useCallback(() => { + // Send message to extension to clear authentication + vscode.postMessage({ + type: "clearCopilotAuth", + }) + setIsAuthenticated(false) + setCopilotModels(null) + setDeviceCodeInfo(null) + setAuthError(null) + }, []) + + const handleRefreshModels = useCallback(() => { + if (!isAuthenticated) { + return + } + + // Send message to extension to fetch Copilot models + vscode.postMessage({ + type: "requestCopilotModels", + }) + }, [isAuthenticated]) + + const onMessage = useCallback( + (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + switch (message.type) { + case "copilotModels": { + const updatedModels = message.copilotModels ?? {} + setCopilotModels(updatedModels) + break + } + case "copilotAuthStatus": { + setIsAuthenticated(message.copilotAuthenticated ?? false) + setIsAuthenticating(false) + if (message.copilotAuthenticated) { + // Clear device code info and error on successful auth + setDeviceCodeInfo(null) + setAuthError(null) + // Auto-refresh models when authentication succeeds + handleRefreshModels() + } + break + } + case "copilotDeviceCode": { + // Show device code info to user + if (message.copilotDeviceCode) { + setDeviceCodeInfo(message.copilotDeviceCode) + } + break + } + case "copilotAuthError": { + setIsAuthenticating(false) + setDeviceCodeInfo(null) + setAuthError(message.error || "Authentication failed") + console.error("Copilot authentication error:", message.error) + break + } + } + }, + [handleRefreshModels], + ) + + useEvent("message", onMessage) + + // Check authentication status on component mount + useEffect(() => { + vscode.postMessage({ + type: "checkCopilotAuth", + }) + }, []) + + // Auto-refresh models when authenticated status changes + useEffect(() => { + if (isAuthenticated) { + handleRefreshModels() + } else { + setCopilotModels(null) + } + }, [isAuthenticated, handleRefreshModels]) + + return ( + <> +
+

{t("settings:providers.copilotAuthentication")}

+
+ {t("settings:providers.copilotDeviceCodeNotice")} +
+ + {/* Error message */} + {authError && ( +
+
❌ Authentication Failed
+
{authError}
+
+ )} + + {/* Device code information */} + {deviceCodeInfo && isAuthenticating && ( +
+
🔐 GitHub Authentication Required
+
+
+
1. Visit GitHub:
+
+ + Open GitHub + + + {deviceCodeInfo.verification_uri} + +
+
+
+
2. Enter this code:
+
+ + {deviceCodeInfo.user_code} + + { + navigator.clipboard.writeText(deviceCodeInfo.user_code) + }}> + Copy + +
+
+
+ ⏱️ Code expires in {Math.floor(deviceCodeInfo.expires_in / 60)} minutes +
+
+
+ )} + + {!isAuthenticated ? ( + + {isAuthenticating + ? deviceCodeInfo + ? t("settings:providers.waitingForAuth") || "Waiting for authentication..." + : t("settings:providers.authenticating") + : t("settings:providers.authenticateWithGitHub")} + + ) : ( +
+
+ + {t("settings:providers.authenticated")} +
+ + {t("settings:providers.clearAuthentication")} + +
+ )} +
+ + {isAuthenticated && ( + <> + {copilotModels && Object.keys(copilotModels).length > 0 ? ( + + ) : ( +
+ {t("settings:providers.copilotModelDescription")} +
+ )} + {modelValidationError && ( +
{modelValidationError}
+ )} + + )} + + ) +} diff --git a/webview-ui/src/components/settings/providers/__tests__/Copilot.test.tsx b/webview-ui/src/components/settings/providers/__tests__/Copilot.test.tsx new file mode 100644 index 00000000000..0c432f253c2 --- /dev/null +++ b/webview-ui/src/components/settings/providers/__tests__/Copilot.test.tsx @@ -0,0 +1,483 @@ +import { render, screen, fireEvent, waitFor } from "@testing-library/react" +import { describe, it, expect, vi, beforeEach, Mock } from "vitest" +import { Copilot } from "../Copilot" +import { useEvent } from "react-use" +import { vscode } from "@src/utils/vscode" +import { useAppTranslation } from "@src/i18n/TranslationContext" +import type { ProviderSettings, ModelInfo } from "@roo-code/types" +import type { OrganizationAllowList } from "@roo/cloud" + +// Mock dependencies +vi.mock("react-use", () => ({ + useEvent: vi.fn(), +})) +vi.mock("@src/utils/vscode", () => ({ + vscode: { + postMessage: vi.fn(), + }, +})) +vi.mock("@src/i18n/TranslationContext", () => ({ + useAppTranslation: vi.fn(), +})) +vi.mock("@src/components/common/VSCodeButtonLink", () => ({ + VSCodeButtonLink: ({ children, href, onClick, ...props }: any) => ( + + {children} + + ), +})) +vi.mock("../ModelPicker", () => ({ + ModelPicker: ({ models, serviceName }: any) => ( +
+ Model Picker for {serviceName}: {Object.keys(models || {}).length} models +
+ ), +})) + +describe("Copilot Component", () => { + const mockUseEvent = useEvent as Mock + const mockPostMessage = vscode.postMessage as Mock + const mockTranslation = useAppTranslation as Mock + + const defaultProps = { + apiConfiguration: { apiProvider: "copilot" } as ProviderSettings, + setApiConfigurationField: vi.fn(), + organizationAllowList: { allowAll: true, providers: {} } as OrganizationAllowList, + modelValidationError: undefined, + } + + const mockT = vi.fn((key: string, options?: any) => { + const translations: Record = { + "settings:providers.copilotAuthentication": "GitHub Copilot Authentication", + "settings:providers.copilotDeviceCodeNotice": + "GitHub Copilot uses OAuth device code flow for secure authentication.", + "settings:providers.authenticating": "Authenticating...", + "settings:providers.waitingForAuth": "Waiting for authentication...", + "settings:providers.authenticateWithGitHub": "Authenticate with GitHub", + "settings:providers.authenticated": "Authenticated", + "settings:providers.clearAuthentication": "Clear Authentication", + "settings:providers.copilotModelDescription": "Allows you to use models on Copilot", + } + return options ? key.replace("{{modelId}}", options.modelId) : translations[key] || key + }) + + beforeEach(() => { + vi.clearAllMocks() + mockTranslation.mockReturnValue({ t: mockT }) + + // Mock the useEvent hook to simulate message handling + let messageHandler: (event: MessageEvent) => void + mockUseEvent.mockImplementation((eventType: string, handler: any) => { + if (eventType === "message") { + messageHandler = handler + } + }) + + // Helper to simulate messages from extension + ;(global as any).simulateMessage = (message: any) => { + if (messageHandler) { + messageHandler({ data: message } as MessageEvent) + } + } + }) + + describe("Authentication Flow", () => { + it("should render authentication section when not authenticated", () => { + render() + + expect(screen.getByText("GitHub Copilot Authentication")).toBeInTheDocument() + expect( + screen.getByText("GitHub Copilot uses OAuth device code flow for secure authentication."), + ).toBeInTheDocument() + expect(screen.getByText("Authenticate with GitHub")).toBeInTheDocument() + }) + + it("should send checkCopilotAuth message on mount", () => { + render() + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "checkCopilotAuth", + }) + }) + + it("should handle authentication button click", () => { + render() + + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "authenticateCopilot", + }) + }) + + it("should show authenticating state", () => { + render() + + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + + expect(screen.getByText("Authenticating...")).toBeInTheDocument() + expect(authButton).toBeDisabled() + }) + + it("should display device code information during authentication", async () => { + render() + + // Start authentication + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + + // Simulate device code message + ;(global as any).simulateMessage({ + type: "copilotDeviceCode", + copilotDeviceCode: { + user_code: "ABCD-1234", + verification_uri: "https://github.com/login/device", + expires_in: 900, + }, + }) + + await waitFor(() => { + expect(screen.getByText("🔐 GitHub Authentication Required")).toBeInTheDocument() + expect(screen.getByText("ABCD-1234")).toBeInTheDocument() + expect(screen.getByText("https://github.com/login/device")).toBeInTheDocument() + expect(screen.getByText(/Code expires in \d+ minutes/)).toBeInTheDocument() + }) + }) + + it("should handle copy device code button", async () => { + // Mock clipboard API + Object.assign(navigator, { + clipboard: { + writeText: vi.fn(), + }, + }) + + render() + + // Start authentication and get device code + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + ;(global as any).simulateMessage({ + type: "copilotDeviceCode", + copilotDeviceCode: { + user_code: "ABCD-1234", + verification_uri: "https://github.com/login/device", + expires_in: 900, + }, + }) + + await waitFor(() => { + const copyButton = screen.getByText("Copy") + fireEvent.click(copyButton) + + expect(navigator.clipboard.writeText).toHaveBeenCalledWith("ABCD-1234") + }) + }) + + it("should show authenticated state", async () => { + render() + + // Simulate successful authentication + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + await waitFor(() => { + expect(screen.getByText("✓")).toBeInTheDocument() + expect(screen.getByText("Authenticated")).toBeInTheDocument() + expect(screen.getByText("Clear Authentication")).toBeInTheDocument() + }) + }) + + it("should handle clear authentication", async () => { + render() + + // First authenticate + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + await waitFor(() => { + const clearButton = screen.getByText("Clear Authentication") + fireEvent.click(clearButton) + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "clearCopilotAuth", + }) + }) + }) + + it("should display authentication errors", async () => { + render() + ;(global as any).simulateMessage({ + type: "copilotAuthError", + error: "Authentication failed due to network error", + }) + + await waitFor(() => { + expect(screen.getByText("❌ Authentication Failed")).toBeInTheDocument() + expect(screen.getByText("Authentication failed due to network error")).toBeInTheDocument() + }) + }) + }) + + describe("Model Management", () => { + it("should request models when authenticated", async () => { + render() + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + await waitFor(() => { + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "requestCopilotModels", + }) + }) + }) + + it("should display model picker when models are available", async () => { + render() + + // Authenticate first + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + // Provide models + ;(global as any).simulateMessage({ + type: "copilotModels", + copilotModels: { + "claude-4": { + maxTokens: 8192, + contextWindow: 200000, + description: "Claude 4", + } as ModelInfo, + "gpt-4": { + maxTokens: 8192, + contextWindow: 128000, + description: "GPT-4", + } as ModelInfo, + }, + }) + + await waitFor(() => { + expect(screen.getByTestId("model-picker")).toBeInTheDocument() + expect(screen.getByText("Model Picker for Copilot: 2 models")).toBeInTheDocument() + }) + }) + + it("should show description when no models available", async () => { + render() + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + ;(global as any).simulateMessage({ + type: "copilotModels", + copilotModels: {}, + }) + + await waitFor(() => { + expect(screen.getByText("Allows you to use models on Copilot")).toBeInTheDocument() + }) + }) + + it("should display model validation errors", async () => { + const propsWithError = { + ...defaultProps, + modelValidationError: "Selected model is not available", + } + + render() + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + await waitFor(() => { + expect(screen.getByText("Selected model is not available")).toBeInTheDocument() + }) + }) + }) + + describe("State Management", () => { + it("should clear device code and error on successful authentication", async () => { + render() + + // First show an error + ;(global as any).simulateMessage({ + type: "copilotAuthError", + error: "Network error", + }) + + await waitFor(() => { + expect(screen.getByText("Network error")).toBeInTheDocument() + }) + + // Then authenticate successfully + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + + await waitFor(() => { + expect(screen.queryByText("Network error")).not.toBeInTheDocument() + expect(screen.getByText("Authenticated")).toBeInTheDocument() + }) + }) + + it("should clear models when authentication is lost", async () => { + render() + + // First authenticate and get models + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + ;(global as any).simulateMessage({ + type: "copilotModels", + copilotModels: { + "claude-4": { description: "Claude 4" } as ModelInfo, + }, + }) + + await waitFor(() => { + expect(screen.getByTestId("model-picker")).toBeInTheDocument() + }) + + // Then lose authentication + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + + await waitFor(() => { + expect(screen.queryByTestId("model-picker")).not.toBeInTheDocument() + expect(screen.getByText("Authenticate with GitHub")).toBeInTheDocument() + }) + }) + + it("should reset states when starting new authentication", () => { + render() + + // Show error first + ;(global as any).simulateMessage({ + type: "copilotAuthError", + error: "Previous error", + }) + + // Start new authentication + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + + expect(screen.queryByText("Previous error")).not.toBeInTheDocument() + }) + }) + + describe("UI States", () => { + it("should show waiting for auth when device code is provided", async () => { + render() + + // Start auth and provide device code + const authButton = screen.getByText("Authenticate with GitHub") + fireEvent.click(authButton) + ;(global as any).simulateMessage({ + type: "copilotDeviceCode", + copilotDeviceCode: { + user_code: "ABCD-1234", + verification_uri: "https://github.com/login/device", + expires_in: 900, + }, + }) + + await waitFor(() => { + expect(screen.getByText("Waiting for authentication...")).toBeInTheDocument() + }) + }) + + it("should handle empty models gracefully", async () => { + render() + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + ;(global as any).simulateMessage({ + type: "copilotModels", + copilotModels: null, + }) + + await waitFor(() => { + expect(screen.getByText("Allows you to use models on Copilot")).toBeInTheDocument() + }) + }) + + it("should disable authenticate button when authenticating", () => { + render() + + const authButton = screen.getByText("Authenticate with GitHub") as HTMLButtonElement + fireEvent.click(authButton) + + expect(authButton.disabled).toBe(true) + }) + + it("should not request models when not authenticated", async () => { + render() + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: false, + }) + + // Wait a bit to ensure no additional calls + await new Promise((resolve) => setTimeout(resolve, 100)) + + expect(mockPostMessage).not.toHaveBeenCalledWith({ + type: "requestCopilotModels", + }) + }) + }) + + describe("Message Handling", () => { + it("should handle multiple message types correctly", async () => { + render() + + // Send multiple messages in sequence + ;(global as any).simulateMessage({ + type: "copilotAuthStatus", + copilotAuthenticated: true, + }) + ;(global as any).simulateMessage({ + type: "copilotModels", + copilotModels: { + "test-model": { description: "Test Model" } as ModelInfo, + }, + }) + ;(global as any).simulateMessage({ + type: "copilotAuthError", + error: "Some error", + }) + + await waitFor(() => { + // Should process all messages appropriately + expect(screen.getByText("Some error")).toBeInTheDocument() + }) + }) + + it("should ignore unknown message types", () => { + render() + + // This should not throw an error + ;(global as any).simulateMessage({ + type: "unknownMessageType", + data: "some data", + }) + + expect(screen.getByText("Authenticate with GitHub")).toBeInTheDocument() + }) + }) +}) diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index f054780b06e..f54daa1113e 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -22,6 +22,7 @@ export { SambaNova } from "./SambaNova" export { Unbound } from "./Unbound" export { Vertex } from "./Vertex" export { VSCodeLM } from "./VSCodeLM" +export { Copilot } from "./Copilot" export { XAI } from "./XAI" export { ZAi } from "./ZAi" export { LiteLLM } from "./LiteLLM" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index b188f3e342f..adaaf699c4e 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -296,6 +296,11 @@ function getSelectedModel({ routerModels["io-intelligence"]?.[id] ?? ioIntelligenceModels[id as keyof typeof ioIntelligenceModels] return { id, info } } + case "copilot": { + const id = apiConfiguration.copilotModelId ?? "gpt-4.1" + const info = routerModels.copilot[id] + return { id, info } + } // case "anthropic": // case "human-relay": // case "fake-ai": diff --git a/webview-ui/src/i18n/locales/ca/settings.json b/webview-ui/src/i18n/locales/ca/settings.json index 59647986848..367e84b2eca 100644 --- a/webview-ui/src/i18n/locales/ca/settings.json +++ b/webview-ui/src/i18n/locales/ca/settings.json @@ -306,6 +306,14 @@ "apiKey": "Clau API", "openAiBaseUrl": "URL base", "getOpenAiApiKey": "Obtenir clau API d'OpenAI", + "copilotAuthentication": "Autenticació de GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot utilitza el flux de codi de dispositiu OAuth per a l'autenticació segura. Seràs redirigit a GitHub per autoritzar l'accés.", + "copilotModelDescription": "Et permet utilitzar models a Copilot", + "authenticating": "Autenticant...", + "waitingForAuth": "Esperant autenticació...", + "authenticateWithGitHub": "Autenticar amb GitHub", + "authenticated": "Autenticat", + "clearAuthentication": "Esborrar autenticació", "mistralApiKey": "Clau API de Mistral", "getMistralApiKey": "Obtenir clau API de Mistral / Codestral", "codestralBaseUrl": "URL base de Codestral (opcional)", diff --git a/webview-ui/src/i18n/locales/de/settings.json b/webview-ui/src/i18n/locales/de/settings.json index 28f3e847fa2..257e824e0a6 100644 --- a/webview-ui/src/i18n/locales/de/settings.json +++ b/webview-ui/src/i18n/locales/de/settings.json @@ -306,6 +306,14 @@ "apiKey": "API-Schlüssel", "openAiBaseUrl": "Basis-URL", "getOpenAiApiKey": "OpenAI API-Schlüssel erhalten", + "copilotAuthentication": "GitHub Copilot Authentifizierung", + "copilotDeviceCodeNotice": "GitHub Copilot verwendet den OAuth-Gerätecode-Flow für sichere Authentifizierung. Du wirst zu GitHub weitergeleitet, um den Zugriff zu autorisieren.", + "copilotModelDescription": "Ermöglicht dir, Modelle auf Copilot zu verwenden", + "authenticating": "Authentifizierung läuft...", + "waitingForAuth": "Warte auf Authentifizierung...", + "authenticateWithGitHub": "Mit GitHub authentifizieren", + "authenticated": "Authentifiziert", + "clearAuthentication": "Authentifizierung löschen", "mistralApiKey": "Mistral API-Schlüssel", "getMistralApiKey": "Mistral / Codestral API-Schlüssel erhalten", "codestralBaseUrl": "Codestral Basis-URL (Optional)", diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index fca3d1ade97..f3c09701a23 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -305,6 +305,14 @@ "apiKey": "API Key", "openAiBaseUrl": "Base URL", "getOpenAiApiKey": "Get OpenAI API Key", + "copilotAuthentication": "GitHub Copilot Authentication", + "copilotDeviceCodeNotice": "GitHub Copilot uses OAuth device code flow for secure authentication. You'll be redirected to GitHub to authorize access.", + "copilotModelDescription": "Allows you to use models on Copilot", + "authenticating": "Authenticating...", + "waitingForAuth": "Waiting for authentication...", + "authenticateWithGitHub": "Authenticate with GitHub", + "authenticated": "Authenticated", + "clearAuthentication": "Clear Authentication", "mistralApiKey": "Mistral API Key", "getMistralApiKey": "Get Mistral / Codestral API Key", "codestralBaseUrl": "Codestral Base URL (Optional)", diff --git a/webview-ui/src/i18n/locales/es/settings.json b/webview-ui/src/i18n/locales/es/settings.json index fd6e1e27150..f1c1722aa02 100644 --- a/webview-ui/src/i18n/locales/es/settings.json +++ b/webview-ui/src/i18n/locales/es/settings.json @@ -306,6 +306,14 @@ "apiKey": "Clave API", "openAiBaseUrl": "URL base", "getOpenAiApiKey": "Obtener clave API de OpenAI", + "copilotAuthentication": "Autenticación de GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot utiliza el flujo de código de dispositivo OAuth para autenticación segura. Serás redirigido a GitHub para autorizar el acceso.", + "copilotModelDescription": "Te permite usar modelos en Copilot", + "authenticating": "Autenticando...", + "waitingForAuth": "Esperando autenticación...", + "authenticateWithGitHub": "Autenticar con GitHub", + "authenticated": "Autenticado", + "clearAuthentication": "Limpiar autenticación", "mistralApiKey": "Clave API de Mistral", "getMistralApiKey": "Obtener clave API de Mistral / Codestral", "codestralBaseUrl": "URL base de Codestral (Opcional)", diff --git a/webview-ui/src/i18n/locales/fr/settings.json b/webview-ui/src/i18n/locales/fr/settings.json index 451e51d0840..a3d5ada8f5f 100644 --- a/webview-ui/src/i18n/locales/fr/settings.json +++ b/webview-ui/src/i18n/locales/fr/settings.json @@ -306,6 +306,14 @@ "apiKey": "Clé API", "openAiBaseUrl": "URL de base", "getOpenAiApiKey": "Obtenir la clé API OpenAI", + "copilotAuthentication": "Authentification GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot utilise le flux de code d'appareil OAuth pour une authentification sécurisée. Tu seras redirigé vers GitHub pour autoriser l'accès.", + "copilotModelDescription": "Te permet d'utiliser les modèles sur Copilot", + "authenticating": "Authentification en cours...", + "waitingForAuth": "En attente d'authentification...", + "authenticateWithGitHub": "S'authentifier avec GitHub", + "authenticated": "Authentifié", + "clearAuthentication": "Effacer l'authentification", "mistralApiKey": "Clé API Mistral", "getMistralApiKey": "Obtenir la clé API Mistral / Codestral", "codestralBaseUrl": "URL de base Codestral (Optionnel)", diff --git a/webview-ui/src/i18n/locales/hi/settings.json b/webview-ui/src/i18n/locales/hi/settings.json index 83a4b7b81b8..ccd21419630 100644 --- a/webview-ui/src/i18n/locales/hi/settings.json +++ b/webview-ui/src/i18n/locales/hi/settings.json @@ -306,6 +306,14 @@ "apiKey": "API कुंजी", "openAiBaseUrl": "बेस URL", "getOpenAiApiKey": "OpenAI API कुंजी प्राप्त करें", + "copilotAuthentication": "GitHub Copilot प्रमाणीकरण", + "copilotDeviceCodeNotice": "GitHub Copilot सुरक्षित प्रमाणीकरण के लिए OAuth डिवाइस कोड फ्लो का उपयोग करता है। आपको पहुंच को अधिकृत करने के लिए GitHub पर रीडायरेक्ट किया जाएगा।", + "copilotModelDescription": "आपको Copilot पर मॉडल का उपयोग करने की अनुमति देता है", + "authenticating": "प्रमाणीकरण हो रहा है...", + "waitingForAuth": "प्रमाणीकरण की प्रतीक्षा में...", + "authenticateWithGitHub": "GitHub के साथ प्रमाणीकरण करें", + "authenticated": "प्रमाणित", + "clearAuthentication": "प्रमाणीकरण साफ़ करें", "mistralApiKey": "Mistral API कुंजी", "getMistralApiKey": "Mistral / Codestral API कुंजी प्राप्त करें", "codestralBaseUrl": "Codestral बेस URL (वैकल्पिक)", diff --git a/webview-ui/src/i18n/locales/id/settings.json b/webview-ui/src/i18n/locales/id/settings.json index fd7edf40cf6..07052d8a383 100644 --- a/webview-ui/src/i18n/locales/id/settings.json +++ b/webview-ui/src/i18n/locales/id/settings.json @@ -310,6 +310,14 @@ "apiKey": "API Key", "openAiBaseUrl": "Base URL", "getOpenAiApiKey": "Dapatkan OpenAI API Key", + "copilotAuthentication": "Autentikasi GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot menggunakan alur kode perangkat OAuth untuk autentikasi yang aman. Kamu akan diarahkan ke GitHub untuk mengotorisasi akses.", + "copilotModelDescription": "Memungkinkan kamu menggunakan model di Copilot", + "authenticating": "Mengautentikasi...", + "waitingForAuth": "Menunggu autentikasi...", + "authenticateWithGitHub": "Autentikasi dengan GitHub", + "authenticated": "Terautentikasi", + "clearAuthentication": "Hapus Autentikasi", "mistralApiKey": "Mistral API Key", "getMistralApiKey": "Dapatkan Mistral / Codestral API Key", "codestralBaseUrl": "Codestral Base URL (Opsional)", diff --git a/webview-ui/src/i18n/locales/it/settings.json b/webview-ui/src/i18n/locales/it/settings.json index c8a8800e4fd..bd7a737116d 100644 --- a/webview-ui/src/i18n/locales/it/settings.json +++ b/webview-ui/src/i18n/locales/it/settings.json @@ -306,6 +306,14 @@ "apiKey": "Chiave API", "openAiBaseUrl": "URL base", "getOpenAiApiKey": "Ottieni chiave API OpenAI", + "copilotAuthentication": "Autenticazione GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot utilizza il flusso di codice del dispositivo OAuth per l'autenticazione sicura. Sarai reindirizzato a GitHub per autorizzare l'accesso.", + "copilotModelDescription": "Ti permette di utilizzare modelli su Copilot", + "authenticating": "Autenticazione in corso...", + "waitingForAuth": "In attesa di autenticazione...", + "authenticateWithGitHub": "Autentica con GitHub", + "authenticated": "Autenticato", + "clearAuthentication": "Cancella autenticazione", "mistralApiKey": "Chiave API Mistral", "getMistralApiKey": "Ottieni chiave API Mistral / Codestral", "codestralBaseUrl": "URL base Codestral (opzionale)", diff --git a/webview-ui/src/i18n/locales/ja/settings.json b/webview-ui/src/i18n/locales/ja/settings.json index d3e7b04baf1..601dd58ce1a 100644 --- a/webview-ui/src/i18n/locales/ja/settings.json +++ b/webview-ui/src/i18n/locales/ja/settings.json @@ -306,6 +306,14 @@ "apiKey": "APIキー", "openAiBaseUrl": "ベースURL", "getOpenAiApiKey": "OpenAI APIキーを取得", + "copilotAuthentication": "GitHub Copilot認証", + "copilotDeviceCodeNotice": "GitHub Copilotは、安全な認証のためにOAuthデバイスコードフローを使用します。アクセスを承認するためにGitHubにリダイレクトされます。", + "copilotModelDescription": "Copilotでモデルを使用できます", + "authenticating": "認証中...", + "waitingForAuth": "認証を待機中...", + "authenticateWithGitHub": "GitHubで認証", + "authenticated": "認証済み", + "clearAuthentication": "認証をクリア", "mistralApiKey": "Mistral APIキー", "getMistralApiKey": "Mistral / Codestral APIキーを取得", "codestralBaseUrl": "Codestral ベースURL(オプション)", diff --git a/webview-ui/src/i18n/locales/ko/settings.json b/webview-ui/src/i18n/locales/ko/settings.json index a5bcd1f3859..073c65d9bec 100644 --- a/webview-ui/src/i18n/locales/ko/settings.json +++ b/webview-ui/src/i18n/locales/ko/settings.json @@ -306,6 +306,14 @@ "openAiApiKey": "OpenAI API 키", "openAiBaseUrl": "기본 URL", "getOpenAiApiKey": "OpenAI API 키 받기", + "copilotAuthentication": "GitHub Copilot 인증", + "copilotDeviceCodeNotice": "GitHub Copilot은 안전한 인증을 위해 OAuth 장치 코드 플로우를 사용합니다. 액세스를 승인하기 위해 GitHub로 리디렉션됩니다.", + "copilotModelDescription": "Copilot에서 모델을 사용할 수 있게 해줍니다", + "authenticating": "인증 중...", + "waitingForAuth": "인증 대기 중...", + "authenticateWithGitHub": "GitHub로 인증", + "authenticated": "인증됨", + "clearAuthentication": "인증 지우기", "mistralApiKey": "Mistral API 키", "getMistralApiKey": "Mistral / Codestral API 키 받기", "codestralBaseUrl": "Codestral 기본 URL (선택사항)", diff --git a/webview-ui/src/i18n/locales/nl/settings.json b/webview-ui/src/i18n/locales/nl/settings.json index b54e021be0a..9fd27bdaa54 100644 --- a/webview-ui/src/i18n/locales/nl/settings.json +++ b/webview-ui/src/i18n/locales/nl/settings.json @@ -306,6 +306,14 @@ "openAiApiKey": "OpenAI API-sleutel", "openAiBaseUrl": "Basis-URL", "getOpenAiApiKey": "OpenAI API-sleutel ophalen", + "copilotAuthentication": "GitHub Copilot authenticatie", + "copilotDeviceCodeNotice": "GitHub Copilot gebruikt OAuth apparaatcode flow voor veilige authenticatie. Je wordt doorverwezen naar GitHub om toegang te autoriseren.", + "copilotModelDescription": "Stelt je in staat om modellen op Copilot te gebruiken", + "authenticating": "Aan het authenticeren...", + "waitingForAuth": "Wachten op authenticatie...", + "authenticateWithGitHub": "Authenticeren met GitHub", + "authenticated": "Geauthenticeerd", + "clearAuthentication": "Authenticatie wissen", "mistralApiKey": "Mistral API-sleutel", "getMistralApiKey": "Mistral / Codestral API-sleutel ophalen", "codestralBaseUrl": "Codestral basis-URL (optioneel)", diff --git a/webview-ui/src/i18n/locales/pl/settings.json b/webview-ui/src/i18n/locales/pl/settings.json index 194bd9029d0..31280e292b4 100644 --- a/webview-ui/src/i18n/locales/pl/settings.json +++ b/webview-ui/src/i18n/locales/pl/settings.json @@ -306,6 +306,14 @@ "openAiApiKey": "Klucz API OpenAI", "openAiBaseUrl": "URL bazowy", "getOpenAiApiKey": "Uzyskaj klucz API OpenAI", + "copilotAuthentication": "Uwierzytelnianie GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot używa przepływu kodu urządzenia OAuth dla bezpiecznego uwierzytelniania. Zostaniesz przekierowany do GitHub, aby autoryzować dostęp.", + "copilotModelDescription": "Pozwala ci używać modeli w Copilot", + "authenticating": "Uwierzytelnianie...", + "waitingForAuth": "Oczekiwanie na uwierzytelnienie...", + "authenticateWithGitHub": "Uwierzytelnij z GitHub", + "authenticated": "Uwierzytelniony", + "clearAuthentication": "Wyczyść uwierzytelnianie", "mistralApiKey": "Klucz API Mistral", "getMistralApiKey": "Uzyskaj klucz API Mistral / Codestral", "codestralBaseUrl": "URL bazowy Codestral (opcjonalnie)", diff --git a/webview-ui/src/i18n/locales/pt-BR/settings.json b/webview-ui/src/i18n/locales/pt-BR/settings.json index 7879dd51543..16d85609ed1 100644 --- a/webview-ui/src/i18n/locales/pt-BR/settings.json +++ b/webview-ui/src/i18n/locales/pt-BR/settings.json @@ -306,6 +306,14 @@ "openAiApiKey": "Chave de API OpenAI", "openAiBaseUrl": "URL Base", "getOpenAiApiKey": "Obter chave de API OpenAI", + "copilotAuthentication": "Autenticação do GitHub Copilot", + "copilotDeviceCodeNotice": "O GitHub Copilot usa o fluxo de código de dispositivo OAuth para autenticação segura. Você será redirecionado para o GitHub para autorizar o acesso.", + "copilotModelDescription": "Permite que você use modelos no Copilot", + "authenticating": "Autenticando...", + "waitingForAuth": "Aguardando autenticação...", + "authenticateWithGitHub": "Autenticar com GitHub", + "authenticated": "Autenticado", + "clearAuthentication": "Limpar autenticação", "mistralApiKey": "Chave de API Mistral", "getMistralApiKey": "Obter chave de API Mistral / Codestral", "codestralBaseUrl": "URL Base Codestral (Opcional)", diff --git a/webview-ui/src/i18n/locales/ru/settings.json b/webview-ui/src/i18n/locales/ru/settings.json index 3744ecedffa..b1a59825b87 100644 --- a/webview-ui/src/i18n/locales/ru/settings.json +++ b/webview-ui/src/i18n/locales/ru/settings.json @@ -306,6 +306,14 @@ "openAiApiKey": "OpenAI API-ключ", "openAiBaseUrl": "Базовый URL", "getOpenAiApiKey": "Получить OpenAI API-ключ", + "copilotAuthentication": "Аутентификация GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot использует поток кода устройства OAuth для безопасной аутентификации. Ты будешь перенаправлен на GitHub для авторизации доступа.", + "copilotModelDescription": "Позволяет тебе использовать модели в Copilot", + "authenticating": "Аутентификация...", + "waitingForAuth": "Ожидание аутентификации...", + "authenticateWithGitHub": "Аутентифицироваться с GitHub", + "authenticated": "Аутентифицирован", + "clearAuthentication": "Очистить аутентификацию", "mistralApiKey": "Mistral API-ключ", "getMistralApiKey": "Получить Mistral / Codestral API-ключ", "codestralBaseUrl": "Базовый URL Codestral (опционально)", diff --git a/webview-ui/src/i18n/locales/tr/settings.json b/webview-ui/src/i18n/locales/tr/settings.json index be1a7a01692..289efac4218 100644 --- a/webview-ui/src/i18n/locales/tr/settings.json +++ b/webview-ui/src/i18n/locales/tr/settings.json @@ -306,6 +306,14 @@ "apiKey": "API Anahtarı", "openAiBaseUrl": "Temel URL", "getOpenAiApiKey": "OpenAI API Anahtarı Al", + "copilotAuthentication": "GitHub Copilot Kimlik Doğrulama", + "copilotDeviceCodeNotice": "GitHub Copilot güvenli kimlik doğrulama için OAuth cihaz kodu akışını kullanır. Erişimi yetkilendirmek için GitHub'a yönlendirileceksin.", + "copilotModelDescription": "Copilot'ta modelleri kullanmanı sağlar", + "authenticating": "Kimlik doğrulanıyor...", + "waitingForAuth": "Kimlik doğrulama bekleniyor...", + "authenticateWithGitHub": "GitHub ile kimlik doğrula", + "authenticated": "Kimlik doğrulandı", + "clearAuthentication": "Kimlik doğrulamayı temizle", "mistralApiKey": "Mistral API Anahtarı", "getMistralApiKey": "Mistral / Codestral API Anahtarı Al", "codestralBaseUrl": "Codestral Temel URL (İsteğe bağlı)", diff --git a/webview-ui/src/i18n/locales/vi/settings.json b/webview-ui/src/i18n/locales/vi/settings.json index 7526cf31a74..5c194ed61b6 100644 --- a/webview-ui/src/i18n/locales/vi/settings.json +++ b/webview-ui/src/i18n/locales/vi/settings.json @@ -306,6 +306,14 @@ "apiKey": "Khóa API", "openAiBaseUrl": "URL cơ sở", "getOpenAiApiKey": "Lấy khóa API OpenAI", + "copilotAuthentication": "Xác thực GitHub Copilot", + "copilotDeviceCodeNotice": "GitHub Copilot sử dụng luồng mã thiết bị OAuth để xác thực an toàn. Bạn sẽ được chuyển hướng đến GitHub để ủy quyền truy cập.", + "copilotModelDescription": "Cho phép bạn sử dụng các mô hình trên Copilot", + "authenticating": "Đang xác thực...", + "waitingForAuth": "Đang chờ xác thực...", + "authenticateWithGitHub": "Xác thực với GitHub", + "authenticated": "Đã xác thực", + "clearAuthentication": "Xóa xác thực", "mistralApiKey": "Khóa API Mistral", "getMistralApiKey": "Lấy khóa API Mistral / Codestral", "codestralBaseUrl": "URL cơ sở Codestral (Tùy chọn)", diff --git a/webview-ui/src/i18n/locales/zh-CN/settings.json b/webview-ui/src/i18n/locales/zh-CN/settings.json index c1985e74906..a9d5a82c1a2 100644 --- a/webview-ui/src/i18n/locales/zh-CN/settings.json +++ b/webview-ui/src/i18n/locales/zh-CN/settings.json @@ -306,6 +306,14 @@ "apiKey": "API 密钥", "openAiBaseUrl": "OpenAI 基础 URL", "getOpenAiApiKey": "获取 OpenAI API 密钥", + "copilotAuthentication": "GitHub Copilot 身份验证", + "copilotDeviceCodeNotice": "GitHub Copilot 使用 OAuth 设备代码流进行安全身份验证。你将被重定向到 GitHub 以授权访问。", + "copilotModelDescription": "允许你在 Copilot 上使用模型", + "authenticating": "正在验证身份...", + "waitingForAuth": "等待身份验证...", + "authenticateWithGitHub": "使用 GitHub 验证身份", + "authenticated": "已验证身份", + "clearAuthentication": "清除身份验证", "mistralApiKey": "Mistral API 密钥", "getMistralApiKey": "获取 Mistral / Codestral API 密钥", "codestralBaseUrl": "Codestral 基础 URL(可选)", diff --git a/webview-ui/src/i18n/locales/zh-TW/settings.json b/webview-ui/src/i18n/locales/zh-TW/settings.json index 6a673de60ac..403e3379f03 100644 --- a/webview-ui/src/i18n/locales/zh-TW/settings.json +++ b/webview-ui/src/i18n/locales/zh-TW/settings.json @@ -306,6 +306,14 @@ "apiKey": "API 金鑰", "openAiBaseUrl": "基礎 URL", "getOpenAiApiKey": "取得 OpenAI API 金鑰", + "copilotAuthentication": "GitHub Copilot 身份驗證", + "copilotDeviceCodeNotice": "GitHub Copilot 使用 OAuth 裝置代碼流程進行安全身份驗證。你將被重新導向至 GitHub 以授權存取。", + "copilotModelDescription": "允許你在 Copilot 上使用模型", + "authenticating": "正在驗證身份...", + "waitingForAuth": "等待身份驗證...", + "authenticateWithGitHub": "使用 GitHub 驗證身份", + "authenticated": "已驗證身份", + "clearAuthentication": "清除身份驗證", "mistralApiKey": "Mistral API 金鑰", "getMistralApiKey": "取得 Mistral/Codestral API 金鑰", "codestralBaseUrl": "Codestral 基礎 URL(選用)", diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts index f14fb892004..cb7709545a0 100644 --- a/webview-ui/src/utils/__tests__/validate.test.ts +++ b/webview-ui/src/utils/__tests__/validate.test.ts @@ -41,6 +41,7 @@ describe("Model Validation Functions", () => { ollama: {}, lmstudio: {}, "io-intelligence": {}, + copilot: {}, } const allowAllOrganization: OrganizationAllowList = { diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index e4c9ed483da..5fa1b25b117 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -194,6 +194,8 @@ function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: str return apiConfiguration.huggingFaceModelId case "io-intelligence": return apiConfiguration.ioIntelligenceModelId + case "copilot": + return apiConfiguration.copilotModelId default: return apiConfiguration.apiModelId } @@ -267,6 +269,9 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels case "io-intelligence": modelId = apiConfiguration.ioIntelligenceModelId break + case "copilot": + modelId = apiConfiguration.copilotModelId + break } if (!modelId) {