diff --git a/packages/types/src/task.ts b/packages/types/src/task.ts index 3f741fc6dd..f3c4a579f4 100644 --- a/packages/types/src/task.ts +++ b/packages/types/src/task.ts @@ -3,6 +3,8 @@ import { z } from "zod" import { RooCodeEventName } from "./events.js" import { type ClineMessage, type TokenUsage } from "./message.js" import { type ToolUsage, type ToolName } from "./tool.js" +import { type Experiments } from "./experiment.js" +import type { TodoItem } from "./todo.js" import type { StaticAppProperties, GitProperties, TelemetryProperties } from "./telemetry.js" /** @@ -13,6 +15,15 @@ export interface TaskProviderState { mode?: string } +export interface CreateTaskOptions { + modeSlug?: string + enableDiff?: boolean + enableCheckpoints?: boolean + fuzzyMatchThreshold?: number + consecutiveMistakeLimit?: number + experiments?: Experiments + initialTodos?: TodoItem[] +} export interface TaskProviderLike { readonly cwd: string readonly appProperties: StaticAppProperties @@ -22,7 +33,7 @@ export interface TaskProviderLike { getCurrentTaskStack(): string[] getRecentTasks(): string[] - createTask(text?: string, images?: string[], parentTask?: TaskLike): Promise + createTask(text?: string, images?: string[], parentTask?: TaskLike, options?: CreateTaskOptions): Promise cancelTask(): Promise clearTask(): Promise resumeTask(taskId: string): void @@ -89,9 +100,10 @@ export interface TaskLike { on(event: K, listener: (...args: TaskEvents[K]) => void | Promise): this off(event: K, listener: (...args: TaskEvents[K]) => void | Promise): this + setMessageResponse(text: string, images?: string[]): void approveAsk(options?: { text?: string; images?: string[] }): void denyAsk(options?: { text?: string; images?: string[] }): void - submitUserMessage(text: string, images?: string[]): void + submitUserMessage(text: string, images?: string[], modeSlug?: string): void abortTask(): void } diff --git a/src/api/providers/featherless.ts b/src/api/providers/featherless.ts index 56d7177de7..2a985e2a87 100644 --- a/src/api/providers/featherless.ts +++ b/src/api/providers/featherless.ts @@ -1,4 +1,9 @@ -import { DEEP_SEEK_DEFAULT_TEMPERATURE, type FeatherlessModelId, featherlessDefaultModelId, featherlessModels } from "@roo-code/types" +import { + DEEP_SEEK_DEFAULT_TEMPERATURE, + type FeatherlessModelId, + featherlessDefaultModelId, + featherlessModels, +} from "@roo-code/types" import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 104cb87206..2b08eede77 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -49,7 +49,7 @@ import { t } from "../../i18n" import { ClineApiReqCancelReason, ClineApiReqInfo } from "../../shared/ExtensionMessage" import { getApiMetrics } from "../../shared/getApiMetrics" import { ClineAskResponse } from "../../shared/WebviewMessage" -import { defaultModeSlug } from "../../shared/modes" +import { defaultModeSlug, getModeBySlug } from "../../shared/modes" import { DiffStrategy } from "../../shared/tools" import { EXPERIMENT_IDS, experiments } from "../../shared/experiments" import { getModelMaxOutputTokens } from "../../shared/api" @@ -838,22 +838,52 @@ export class Task extends EventEmitter implements TaskLike { this.handleWebviewAskResponse("noButtonClicked", text, images) } - public submitUserMessage(text: string, images?: string[]): void { + public submitUserMessage(text: string, images?: string[], modeSlug?: string): void { try { text = (text ?? "").trim() images = images ?? [] - if (text.length === 0 && images.length === 0) { - return - } - const provider = this.providerRef.deref() - if (provider) { - provider.postMessageToWebview({ type: "invoke", invoke: "sendMessage", text, images }) - } else { - console.error("[Task#submitUserMessage] Provider reference lost") - } + // Run asynchronously to allow awaiting mode switch before sending the message + void (async () => { + // If a mode slug is provided, handle the mode switch first (same behavior as createTask) + try { + const modeSlugValue = (modeSlug ?? "").trim() + if (modeSlugValue.length > 0 && provider) { + const customModes = await provider.customModesManager.getCustomModes() + const targetMode = getModeBySlug(modeSlugValue, customModes) + if (targetMode) { + await provider.handleModeSwitch(targetMode.slug) + provider.log(`[Task#submitUserMessage] Applied mode from modeSlug: '${targetMode.slug}'`) + } else { + provider.log(`[Task#submitUserMessage] Ignoring invalid modeSlug: '${modeSlugValue}'.`) + } + } + } catch (err) { + provider?.log( + `[Task#submitUserMessage] Failed to apply modeSlug: ${ + err instanceof Error ? err.message : String(err) + }`, + ) + } + + // If there's no content to send, exit after potential mode switch + if (text.length === 0 && images.length === 0) { + return + } + + if (provider) { + await provider.postMessageToWebview({ + type: "invoke", + invoke: "sendMessage", + text, + images, + }) + } else { + console.error("[Task#submitUserMessage] Provider reference lost") + } + })() } catch (error) { console.error("[Task#submitUserMessage] Failed to submit user message:", error) } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index b6c00ee379..f12f01ff1e 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -29,6 +29,7 @@ import { type TerminalActionId, type TerminalActionPromptType, type HistoryItem, + type CreateTaskOptions, type ClineAsk, RooCodeEventName, requestyDefaultModelId, @@ -81,7 +82,7 @@ import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/provi import { ContextProxy } from "../config/ContextProxy" import { ProviderSettingsManager } from "../config/ProviderSettingsManager" import { CustomModesManager } from "../config/CustomModesManager" -import { Task, TaskOptions } from "../task/Task" +import { Task } from "../task/Task" import { getSystemPromptFilePath } from "../prompts/sections/custom-system-prompt" import { webviewMessageHandler } from "./webviewMessageHandler" @@ -750,22 +751,7 @@ export class ClineProvider // from the stack and the caller is resumed in this way we can have a chain // of tasks, each one being a sub task of the previous one until the main // task is finished. - public async createTask( - text?: string, - images?: string[], - parentTask?: Task, - options: Partial< - Pick< - TaskOptions, - | "enableDiff" - | "enableCheckpoints" - | "fuzzyMatchThreshold" - | "consecutiveMistakeLimit" - | "experiments" - | "initialTodos" - > - > = {}, - ) { + public async createTask(text?: string, images?: string[], parentTask?: Task, options: CreateTaskOptions = {}) { const { apiConfiguration, organizationAllowList, @@ -781,6 +767,30 @@ export class ClineProvider throw new OrganizationAllowListViolationError(t("common:errors.violated_organization_allowlist")) } + // Initializes task with the correct mode and associated provider profile + try { + const modeSlugFromBridge: string | undefined = (options as any)?.modeSlug + if (typeof modeSlugFromBridge === "string" && modeSlugFromBridge.trim().length > 0) { + const customModes = await this.customModesManager.getCustomModes() + const targetMode = getModeBySlug(modeSlugFromBridge, customModes) + if (targetMode) { + // Switch provider/global mode first so Task reads it during initialization + await this.handleModeSwitch(targetMode.slug) + this.log(`[createTask] Applied mode from bridge: '${targetMode.slug}'`) + } else { + this.log( + `[createTask] Ignoring invalid modeSlug from bridge: '${modeSlugFromBridge}'. Falling back to current mode.`, + ) + } + } + } catch (err) { + this.log( + `[createTask] Failed to apply modeSlug from bridge: ${ + err instanceof Error ? err.message : String(err) + }`, + ) + } + const task = new Task({ provider: this, apiConfiguration, diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 80c2f537a2..76be66aa00 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -2208,6 +2208,84 @@ describe("ClineProvider", () => { }) }) }) +describe("Bridge modeSlug handling", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + + beforeEach(() => { + vi.clearAllMocks() + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: vi.fn(), + update: vi.fn(), + keys: vi.fn().mockReturnValue([]), + }, + secrets: { + get: vi.fn(), + store: vi.fn(), + delete: vi.fn(), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: vi.fn(), + clear: vi.fn(), + dispose: vi.fn(), + } as unknown as vscode.OutputChannel + + mockWebviewView = { + webview: { + postMessage: vi.fn(), + html: "", + options: {}, + onDidReceiveMessage: vi.fn(), + asWebviewUri: vi.fn(), + cspSource: "vscode-webview://test-csp-source", + }, + visible: true, + onDidDispose: vi.fn(), + onDidChangeVisibility: vi.fn(), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + }) + + it("applies modeSlug from bridge options when starting task", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Spy on handleModeSwitch to ensure it's invoked with the bridge-provided mode + const handleModeSwitchSpy = vi.spyOn(provider, "handleModeSwitch").mockResolvedValue(undefined as any) + + // Ensure getModeBySlug returns a valid mode for the provided slug + const { getModeBySlug } = await import("../../../shared/modes") + vi.mocked(getModeBySlug).mockReturnValueOnce({ + slug: "architect", + name: "Architect Mode", + roleDefinition: "You are an architect", + groups: ["read", "edit"] as any, + } as any) + + // Pass modeSlug through the options object (as provided by the bridge package) + await provider.createTask("Started from bridge", undefined, undefined, { + experiments: {}, + modeSlug: "architect", + } as any) + + expect(handleModeSwitchSpy).toHaveBeenCalledWith("architect") + }) +}) describe("Project MCP Settings", () => { let provider: ClineProvider diff --git a/src/extension/api.ts b/src/extension/api.ts index 2fc51e7afb..315ad9304a 100644 --- a/src/extension/api.ts +++ b/src/extension/api.ts @@ -108,11 +108,13 @@ export class API extends EventEmitter implements RooCodeAPI { text, images, newTab, + modeSlug, }: { configuration: RooCodeSettings text?: string images?: string[] newTab?: boolean + modeSlug?: string }) { let provider: ClineProvider @@ -161,6 +163,7 @@ export class API extends EventEmitter implements RooCodeAPI { const cline = await provider.createTask(text, images, undefined, { consecutiveMistakeLimit: Number.MAX_SAFE_INTEGER, + modeSlug, }) if (!cline) {