Skip to content

Commit 47855ed

Browse files
ctehannesrudolph
authored andcommitted
Mode and provider profile selector (#7545)
1 parent 0126507 commit 47855ed

File tree

10 files changed

+436
-357
lines changed

10 files changed

+436
-357
lines changed

packages/cloud/src/bridge/BaseChannel.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ export abstract class BaseChannel<TCommand = unknown, TEventName extends string
8383
/**
8484
* Handle incoming commands - must be implemented by subclasses.
8585
*/
86-
public abstract handleCommand(command: TCommand): void
86+
public abstract handleCommand(command: TCommand): Promise<void>
8787

8888
/**
8989
* Handle connection-specific logic.

packages/cloud/src/bridge/ExtensionChannel.ts

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ export class ExtensionChannel extends BaseChannel<
5353
this.setupListeners()
5454
}
5555

56-
/**
57-
* Handle extension-specific commands from the web app
58-
*/
59-
public handleCommand(command: ExtensionBridgeCommand): void {
56+
public async handleCommand(command: ExtensionBridgeCommand): Promise<void> {
6057
if (command.instanceId !== this.instanceId) {
6158
console.log(`[ExtensionChannel] command -> instance id mismatch | ${this.instanceId}`, {
6259
messageInstanceId: command.instanceId,
@@ -69,13 +66,22 @@ export class ExtensionChannel extends BaseChannel<
6966
console.log(`[ExtensionChannel] command -> createTask() | ${command.instanceId}`, {
7067
text: command.payload.text?.substring(0, 100) + "...",
7168
hasImages: !!command.payload.images,
69+
mode: command.payload.mode,
70+
providerProfile: command.payload.providerProfile,
7271
})
7372

74-
this.provider.createTask(command.payload.text, command.payload.images)
73+
this.provider.createTask(
74+
command.payload.text,
75+
command.payload.images,
76+
undefined, // parentTask
77+
undefined, // options
78+
{ mode: command.payload.mode, currentApiConfigName: command.payload.providerProfile },
79+
)
80+
7581
break
7682
}
7783
case ExtensionBridgeCommandName.StopTask: {
78-
const instance = this.updateInstance()
84+
const instance = await this.updateInstance()
7985

8086
if (instance.task.taskStatus === TaskStatus.Running) {
8187
console.log(`[ExtensionChannel] command -> cancelTask() | ${command.instanceId}`)
@@ -86,14 +92,14 @@ export class ExtensionChannel extends BaseChannel<
8692
this.provider.clearTask()
8793
this.provider.postStateToWebview()
8894
}
95+
8996
break
9097
}
9198
case ExtensionBridgeCommandName.ResumeTask: {
9299
console.log(`[ExtensionChannel] command -> resumeTask() | ${command.instanceId}`, {
93100
taskId: command.payload.taskId,
94101
})
95102

96-
// Resume the task from history by taskId
97103
this.provider.resumeTask(command.payload.taskId)
98104
this.provider.postStateToWebview()
99105
break
@@ -122,20 +128,20 @@ export class ExtensionChannel extends BaseChannel<
122128
}
123129

124130
private async registerInstance(_socket: Socket): Promise<void> {
125-
const instance = this.updateInstance()
131+
const instance = await this.updateInstance()
126132
await this.publish(ExtensionSocketEvents.REGISTER, instance)
127133
}
128134

129135
private async unregisterInstance(_socket: Socket): Promise<void> {
130-
const instance = this.updateInstance()
136+
const instance = await this.updateInstance()
131137
await this.publish(ExtensionSocketEvents.UNREGISTER, instance)
132138
}
133139

134140
private startHeartbeat(socket: Socket): void {
135141
this.stopHeartbeat()
136142

137143
this.heartbeatInterval = setInterval(async () => {
138-
const instance = this.updateInstance()
144+
const instance = await this.updateInstance()
139145

140146
try {
141147
socket.emit(ExtensionSocketEvents.HEARTBEAT, instance)
@@ -172,11 +178,11 @@ export class ExtensionChannel extends BaseChannel<
172178
] as const
173179

174180
eventMapping.forEach(({ from, to }) => {
175-
// Create and store the listener function for cleanup/
176-
const listener = (..._args: unknown[]) => {
181+
// Create and store the listener function for cleanup.
182+
const listener = async (..._args: unknown[]) => {
177183
this.publish(ExtensionSocketEvents.EVENT, {
178184
type: to,
179-
instance: this.updateInstance(),
185+
instance: await this.updateInstance(),
180186
timestamp: Date.now(),
181187
})
182188
}
@@ -195,10 +201,16 @@ export class ExtensionChannel extends BaseChannel<
195201
this.eventListeners.clear()
196202
}
197203

198-
private updateInstance(): ExtensionInstance {
204+
private async updateInstance(): Promise<ExtensionInstance> {
199205
const task = this.provider?.getCurrentTask()
200206
const taskHistory = this.provider?.getRecentTasks() ?? []
201207

208+
const mode = await this.provider?.getMode()
209+
const modes = (await this.provider?.getModes()) ?? []
210+
211+
const providerProfile = await this.provider?.getProviderProfile()
212+
const providerProfiles = (await this.provider?.getProviderProfiles()) ?? []
213+
202214
this.extensionInstance = {
203215
...this.extensionInstance,
204216
appProperties: this.extensionInstance.appProperties ?? this.provider.appProperties,
@@ -213,6 +225,10 @@ export class ExtensionChannel extends BaseChannel<
213225
: { taskId: "", taskStatus: TaskStatus.None },
214226
taskAsk: task?.taskAsk,
215227
taskHistory,
228+
mode,
229+
providerProfile,
230+
modes,
231+
providerProfiles,
216232
}
217233

218234
return this.extensionInstance

packages/cloud/src/bridge/TaskChannel.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ export class TaskChannel extends BaseChannel<
7373
super(instanceId)
7474
}
7575

76-
public handleCommand(command: TaskBridgeCommand): void {
76+
public async handleCommand(command: TaskBridgeCommand): Promise<void> {
7777
const task = this.subscribedTasks.get(command.taskId)
7878

7979
if (!task) {
@@ -87,14 +87,22 @@ export class TaskChannel extends BaseChannel<
8787
`[TaskChannel] ${TaskBridgeCommandName.Message} ${command.taskId} -> submitUserMessage()`,
8888
command,
8989
)
90-
task.submitUserMessage(command.payload.text, command.payload.images)
90+
91+
await task.submitUserMessage(
92+
command.payload.text,
93+
command.payload.images,
94+
command.payload.mode,
95+
command.payload.providerProfile,
96+
)
97+
9198
break
9299

93100
case TaskBridgeCommandName.ApproveAsk:
94101
console.log(
95102
`[TaskChannel] ${TaskBridgeCommandName.ApproveAsk} ${command.taskId} -> approveAsk()`,
96103
command,
97104
)
105+
98106
task.approveAsk(command.payload)
99107
break
100108

packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ describe("ExtensionChannel", () => {
5353
postStateToWebview: vi.fn(),
5454
postMessageToWebview: vi.fn(),
5555
getTelemetryProperties: vi.fn(),
56+
getMode: vi.fn().mockResolvedValue("code"),
57+
getModes: vi.fn().mockResolvedValue([
58+
{ slug: "code", name: "Code", description: "Code mode" },
59+
{ slug: "architect", name: "Architect", description: "Architect mode" },
60+
]),
61+
getProviderProfile: vi.fn().mockResolvedValue("default"),
62+
getProviderProfiles: vi.fn().mockResolvedValue([{ name: "default", description: "Default profile" }]),
5663
on: vi.fn((event: keyof TaskProviderEvents, listener: (...args: unknown[]) => unknown) => {
5764
if (!eventListeners.has(event)) {
5865
eventListeners.set(event, new Set())
@@ -184,6 +191,9 @@ describe("ExtensionChannel", () => {
184191
// Connect the socket to enable publishing
185192
await extensionChannel.onConnect(mockSocket)
186193

194+
// Clear the mock calls from the connection (which emits a register event)
195+
;(mockSocket.emit as any).mockClear()
196+
187197
// Get a listener that was registered for TaskStarted
188198
const taskStartedListeners = eventListeners.get(RooCodeEventName.TaskStarted)
189199
expect(taskStartedListeners).toBeDefined()
@@ -192,7 +202,7 @@ describe("ExtensionChannel", () => {
192202
// Trigger the listener
193203
const listener = Array.from(taskStartedListeners!)[0]
194204
if (listener) {
195-
listener("test-task-id")
205+
await listener("test-task-id")
196206
}
197207

198208
// Verify the event was published to the socket

packages/cloud/src/bridge/__tests__/TaskChannel.test.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,12 @@ describe("TaskChannel", () => {
333333

334334
taskChannel.handleCommand(command)
335335

336-
expect(mockTask.submitUserMessage).toHaveBeenCalledWith(command.payload.text, command.payload.images)
336+
expect(mockTask.submitUserMessage).toHaveBeenCalledWith(
337+
command.payload.text,
338+
command.payload.images,
339+
undefined,
340+
undefined,
341+
)
337342
})
338343

339344
it("should handle ApproveAsk command", () => {

packages/types/npm/package.metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@roo-code/types",
3-
"version": "1.65.0",
3+
"version": "1.66.0",
44
"description": "TypeScript type definitions for Roo Code.",
55
"publishConfig": {
66
"access": "public",

packages/types/src/cloud.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ export const extensionInstanceSchema = z.object({
378378
task: extensionTaskSchema,
379379
taskAsk: clineMessageSchema.optional(),
380380
taskHistory: z.array(z.string()),
381+
mode: z.string().optional(),
382+
modes: z.array(z.object({ slug: z.string(), name: z.string() })).optional(),
383+
providerProfile: z.string().optional(),
384+
providerProfiles: z.array(z.object({ name: z.string(), provider: z.string().optional() })).optional(),
381385
})
382386

383387
export type ExtensionInstance = z.infer<typeof extensionInstanceSchema>
@@ -398,6 +402,9 @@ export enum ExtensionBridgeEventName {
398402
TaskResumable = RooCodeEventName.TaskResumable,
399403
TaskIdle = RooCodeEventName.TaskIdle,
400404

405+
ModeChanged = RooCodeEventName.ModeChanged,
406+
ProviderProfileChanged = RooCodeEventName.ProviderProfileChanged,
407+
401408
InstanceRegistered = "instance_registered",
402409
InstanceUnregistered = "instance_unregistered",
403410
HeartbeatUpdated = "heartbeat_updated",
@@ -469,6 +476,18 @@ export const extensionBridgeEventSchema = z.discriminatedUnion("type", [
469476
instance: extensionInstanceSchema,
470477
timestamp: z.number(),
471478
}),
479+
z.object({
480+
type: z.literal(ExtensionBridgeEventName.ModeChanged),
481+
instance: extensionInstanceSchema,
482+
mode: z.string(),
483+
timestamp: z.number(),
484+
}),
485+
z.object({
486+
type: z.literal(ExtensionBridgeEventName.ProviderProfileChanged),
487+
instance: extensionInstanceSchema,
488+
providerProfile: z.object({ name: z.string(), provider: z.string().optional() }),
489+
timestamp: z.number(),
490+
}),
472491
])
473492

474493
export type ExtensionBridgeEvent = z.infer<typeof extensionBridgeEventSchema>
@@ -490,6 +509,8 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
490509
payload: z.object({
491510
text: z.string(),
492511
images: z.array(z.string()).optional(),
512+
mode: z.string().optional(),
513+
providerProfile: z.string().optional(),
493514
}),
494515
timestamp: z.number(),
495516
}),
@@ -502,9 +523,7 @@ export const extensionBridgeCommandSchema = z.discriminatedUnion("type", [
502523
z.object({
503524
type: z.literal(ExtensionBridgeCommandName.ResumeTask),
504525
instanceId: z.string(),
505-
payload: z.object({
506-
taskId: z.string(),
507-
}),
526+
payload: z.object({ taskId: z.string() }),
508527
timestamp: z.number(),
509528
}),
510529
])
@@ -558,6 +577,8 @@ export const taskBridgeCommandSchema = z.discriminatedUnion("type", [
558577
payload: z.object({
559578
text: z.string(),
560579
images: z.array(z.string()).optional(),
580+
mode: z.string().optional(),
581+
providerProfile: z.string().optional(),
561582
}),
562583
timestamp: z.number(),
563584
}),

packages/types/src/task.ts

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { z } from "zod"
22

33
import { RooCodeEventName } from "./events.js"
44
import type { RooCodeSettings } from "./global-settings.js"
5-
import type { ClineMessage, QueuedMessage, TokenUsage } from "./message.js"
5+
import type { ClineMessage, TokenUsage } from "./message.js"
66
import type { ToolUsage, ToolName } from "./tool.js"
77
import type { StaticAppProperties, GitProperties, TelemetryProperties } from "./telemetry.js"
88
import type { TodoItem } from "./todo.js"
@@ -59,6 +59,8 @@ export interface TaskProviderLike {
5959

6060
export type TaskProviderEvents = {
6161
[RooCodeEventName.TaskCreated]: [task: TaskLike]
62+
63+
// Proxied from the Task EventEmitter.
6264
[RooCodeEventName.TaskStarted]: [taskId: string]
6365
[RooCodeEventName.TaskCompleted]: [taskId: string, tokenUsage: TokenUsage, toolUsage: ToolUsage]
6466
[RooCodeEventName.TaskAborted]: [taskId: string]
@@ -68,21 +70,14 @@ export type TaskProviderEvents = {
6870
[RooCodeEventName.TaskInteractive]: [taskId: string]
6971
[RooCodeEventName.TaskResumable]: [taskId: string]
7072
[RooCodeEventName.TaskIdle]: [taskId: string]
71-
[RooCodeEventName.TaskPaused]: [taskId: string]
72-
[RooCodeEventName.TaskUnpaused]: [taskId: string]
7373
[RooCodeEventName.TaskSpawned]: [taskId: string]
74-
75-
[RooCodeEventName.TaskUserMessage]: [taskId: string]
76-
77-
[RooCodeEventName.TaskTokenUsageUpdated]: [taskId: string, tokenUsage: TokenUsage]
78-
7974
[RooCodeEventName.ModeChanged]: [mode: string]
8075
[RooCodeEventName.ProviderProfileChanged]: [config: { name: string; provider?: string }]
8176
}
8277

8378
/**
84-
* TaskLike
85-
*/
79+
* TaskLike
80+
*/
8681

8782
export interface CreateTaskOptions {
8883
enableDiff?: boolean
@@ -110,14 +105,11 @@ export type TaskMetadata = z.infer<typeof taskMetadataSchema>
110105

111106
export interface TaskLike {
112107
readonly taskId: string
113-
readonly rootTaskId?: string
114-
readonly parentTaskId?: string
115-
readonly childTaskId?: string
116-
readonly metadata: TaskMetadata
117108
readonly taskStatus: TaskStatus
118109
readonly taskAsk: ClineMessage | undefined
119-
readonly queuedMessages: QueuedMessage[]
120-
readonly tokenUsage: TokenUsage | undefined
110+
readonly metadata: TaskMetadata
111+
112+
readonly rootTask?: TaskLike
121113

122114
on<K extends keyof TaskEvents>(event: K, listener: (...args: TaskEvents[K]) => void | Promise<void>): this
123115
off<K extends keyof TaskEvents>(event: K, listener: (...args: TaskEvents[K]) => void | Promise<void>): this
@@ -141,15 +133,14 @@ export type TaskEvents = {
141133
[RooCodeEventName.TaskIdle]: [taskId: string]
142134

143135
// Subtask Lifecycle
144-
[RooCodeEventName.TaskPaused]: [taskId: string]
145-
[RooCodeEventName.TaskUnpaused]: [taskId: string]
136+
[RooCodeEventName.TaskPaused]: []
137+
[RooCodeEventName.TaskUnpaused]: []
146138
[RooCodeEventName.TaskSpawned]: [taskId: string]
147139

148140
// Task Execution
149141
[RooCodeEventName.Message]: [{ action: "created" | "updated"; message: ClineMessage }]
150142
[RooCodeEventName.TaskModeSwitched]: [taskId: string, mode: string]
151143
[RooCodeEventName.TaskAskResponded]: []
152-
[RooCodeEventName.TaskUserMessage]: [taskId: string]
153144

154145
// Task Analytics
155146
[RooCodeEventName.TaskToolFailed]: [taskId: string, tool: ToolName, error: string]

0 commit comments

Comments
 (0)