diff --git a/packages/a2a-server/src/http/app.test.ts b/packages/a2a-server/src/http/app.test.ts index f427bdfe635..f5aff574cd5 100644 --- a/packages/a2a-server/src/http/app.test.ts +++ b/packages/a2a-server/src/http/app.test.ts @@ -42,6 +42,17 @@ import type { Command, CommandContext } from '../commands/types.js'; const mockToolConfirmationFn = async () => ({}) as unknown as ToolCallConfirmationDetails; +interface CoderAgentMetadata { + kind?: string; +} + +interface ToolCallData { + status?: string; + request?: { + callId: string; + }; +} + const streamToSSEEvents = ( stream: string, ): SendStreamingMessageSuccessResponse[] => @@ -145,24 +156,27 @@ describe('E2E Tests', () => { assertTaskCreationAndWorkingStatus(events); // Status update: text-content - const textContentEvent = events[2].result as TaskStatusUpdateEvent; - expect(textContentEvent.kind).toBe('status-update'); + const textContentEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'text-content', + )?.result as TaskStatusUpdateEvent; + expect(textContentEvent).toBeDefined(); expect(textContentEvent.status.state).toBe('working'); - expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'text-content', - }); expect(textContentEvent.status.message?.parts).toMatchObject([ { kind: 'text', text: 'Hello how are you?' }, ]); // Status update: input-required (final) - const finalEvent = events[3].result as TaskStatusUpdateEvent; - expect(finalEvent.kind).toBe('status-update'); - expect(finalEvent.status?.state).toBe('input-required'); + const finalEvent = events.find( + (e) => + (e.result as { status?: { state: string } })?.status?.state === + 'input-required', + )?.result as TaskStatusUpdateEvent; + expect(finalEvent).toBeDefined(); expect(finalEvent.final).toBe(true); assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(4); }); it('should create a new task, schedule a tool call, and wait for approval', async () => { @@ -205,36 +219,34 @@ describe('E2E Tests', () => { const events = streamToSSEEvents(res.text); assertTaskCreationAndWorkingStatus(events); - // Status update: working - const workingEvent2 = events[2].result as TaskStatusUpdateEvent; - expect(workingEvent2.kind).toBe('status-update'); - expect(workingEvent2.status.state).toBe('working'); - expect(workingEvent2.metadata?.['coderAgent']).toMatchObject({ - kind: 'state-change', - }); - - // Status update: tool-call-update - const toolCallUpdateEvent = events[3].result as TaskStatusUpdateEvent; - expect(toolCallUpdateEvent.kind).toBe('status-update'); + // Status update: tool-call-update (validating) + const toolCallUpdateEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-update' && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.status === 'validating', + )?.result as TaskStatusUpdateEvent; + expect(toolCallUpdateEvent).toBeDefined(); expect(toolCallUpdateEvent.status.state).toBe('working'); - expect(toolCallUpdateEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); expect(toolCallUpdateEvent.status.message?.parts).toMatchObject([ { data: { - status: 'validating', request: { callId: 'test-call-id' }, }, }, ]); // State update: awaiting_approval update - const toolCallConfirmationEvent = events[4].result as TaskStatusUpdateEvent; - expect(toolCallConfirmationEvent.kind).toBe('status-update'); - expect(toolCallConfirmationEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-confirmation', - }); + const toolCallConfirmationEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-confirmation', + )?.result as TaskStatusUpdateEvent; + expect(toolCallConfirmationEvent).toBeDefined(); expect(toolCallConfirmationEvent.status.message?.parts).toMatchObject([ { data: { @@ -246,7 +258,6 @@ describe('E2E Tests', () => { expect(toolCallConfirmationEvent.status?.state).toBe('working'); assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(6); }); it('should handle multiple tool calls in a single turn', async () => { @@ -312,47 +323,55 @@ describe('E2E Tests', () => { const events = streamToSSEEvents(res.text); assertTaskCreationAndWorkingStatus(events); - // Second working update - const workingEvent = events[2].result as TaskStatusUpdateEvent; - expect(workingEvent.kind).toBe('status-update'); - expect(workingEvent.status.state).toBe('working'); - - // State Update: Validate the first tool call - const toolCallValidateEvent1 = events[3].result as TaskStatusUpdateEvent; - expect(toolCallValidateEvent1.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(toolCallValidateEvent1.status.message?.parts).toMatchObject([ - { - data: { - status: 'validating', - request: { callId: 'test-call-id-1' }, - }, - }, - ]); - - // --- Assert the event stream --- - // 1. Initial "submitted" status. - expect((events[0].result as TaskStatusUpdateEvent).status.state).toBe( - 'submitted', + const findToolCallUpdateEvent = (callId: string, status: string) => + events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-update' && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.request?.callId === callId && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.status === status, + )?.result as TaskStatusUpdateEvent; + + const findToolCallConfirmationEvent = (callId: string, status: string) => + events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-confirmation' && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.request?.callId === callId && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.status === status, + )?.result as TaskStatusUpdateEvent; + + // A "state-change" event from the agent. + const stateChangeEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'state-change', + )?.result as TaskStatusUpdateEvent; + expect(stateChangeEvent).toBeDefined(); + + // Tool 1 is validating. + const toolCallValidateEvent1 = findToolCallUpdateEvent( + 'test-call-id-1', + 'validating', ); - - // 2. "working" status after receiving the user prompt. - expect((events[1].result as TaskStatusUpdateEvent).status.state).toBe( - 'working', - ); - - // 3. A "state-change" event from the agent. - expect(events[2].result.metadata?.['coderAgent']).toMatchObject({ - kind: 'state-change', - }); - - // 4. Tool 1 is validating. - const toolCallUpdate1 = events[3].result as TaskStatusUpdateEvent; - expect(toolCallUpdate1.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(toolCallUpdate1.status.message?.parts).toMatchObject([ + expect(toolCallValidateEvent1).toBeDefined(); + expect(toolCallValidateEvent1.status.message?.parts).toMatchObject([ { data: { request: { callId: 'test-call-id-1' }, @@ -361,12 +380,13 @@ describe('E2E Tests', () => { }, ]); - // 5. Tool 2 is validating. - const toolCallUpdate2 = events[4].result as TaskStatusUpdateEvent; - expect(toolCallUpdate2.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(toolCallUpdate2.status.message?.parts).toMatchObject([ + // Tool 2 is validating. + const toolCallValidateEvent2 = findToolCallUpdateEvent( + 'test-call-id-2', + 'validating', + ); + expect(toolCallValidateEvent2).toBeDefined(); + expect(toolCallValidateEvent2.status.message?.parts).toMatchObject([ { data: { request: { callId: 'test-call-id-2' }, @@ -375,11 +395,12 @@ describe('E2E Tests', () => { }, ]); - // 6. Tool 1 is awaiting approval. - const toolCallAwaitEvent = events[5].result as TaskStatusUpdateEvent; - expect(toolCallAwaitEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-confirmation', - }); + // Tool 1 is awaiting approval. + const toolCallAwaitEvent = findToolCallConfirmationEvent( + 'test-call-id-1', + 'awaiting_approval', + ); + expect(toolCallAwaitEvent).toBeDefined(); expect(toolCallAwaitEvent.status.message?.parts).toMatchObject([ { data: { @@ -389,14 +410,15 @@ describe('E2E Tests', () => { }, ]); - // 7. The final event is "input-required". - const finalEvent = events[6].result as TaskStatusUpdateEvent; - expect(finalEvent.final).toBe(true); + // The final event is "input-required". + const finalEvent = events.find( + (e) => (e.result as { final?: boolean })?.final, + )?.result as TaskStatusUpdateEvent; + expect(finalEvent).toBeDefined(); expect(finalEvent.status.state).toBe('input-required'); // The scheduler now waits for approval, so no more events are sent. assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(7); }); it('should handle multiple tool calls sequentially in YOLO mode', async () => { @@ -474,10 +496,10 @@ describe('E2E Tests', () => { // --- Assert the sequential execution flow --- const eventStream = events.slice(2).map((e) => { const update = e.result as TaskStatusUpdateEvent; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const agentData = update.metadata?.['coderAgent'] as any; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const toolData = update.status.message?.parts[0] as any; + const agentData = update.metadata?.['coderAgent'] as { kind: string }; + const toolData = update.status.message?.parts[0] as { + data?: { status: string; request: { callId: string } }; + }; if (!toolData) { return { kind: agentData.kind }; } @@ -593,83 +615,44 @@ describe('E2E Tests', () => { const events = streamToSSEEvents(res.text); assertTaskCreationAndWorkingStatus(events); - // Status update: working - const workingEvent2 = events[2].result as TaskStatusUpdateEvent; - expect(workingEvent2.kind).toBe('status-update'); - expect(workingEvent2.status.state).toBe('working'); - - // Status update: tool-call-update (validating) - const validatingEvent = events[3].result as TaskStatusUpdateEvent; - expect(validatingEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(validatingEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'validating', - request: { callId: 'test-call-id-no-approval' }, - }, - }, - ]); - - // Status update: tool-call-update (scheduled) - const scheduledEvent = events[4].result as TaskStatusUpdateEvent; - expect(scheduledEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(scheduledEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'scheduled', - request: { callId: 'test-call-id-no-approval' }, - }, - }, - ]); - - // Status update: tool-call-update (executing) - const executingEvent = events[5].result as TaskStatusUpdateEvent; - expect(executingEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(executingEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'executing', - request: { callId: 'test-call-id-no-approval' }, - }, - }, - ]); - - // Status update: tool-call-update (success) - const successEvent = events[6].result as TaskStatusUpdateEvent; - expect(successEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); + const findToolEvent = (status: string) => + events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-update' && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.status === status, + )?.result as TaskStatusUpdateEvent; + + // Status update: tool-call-update transitions + expect(findToolEvent('validating')).toBeDefined(); + expect(findToolEvent('scheduled')).toBeDefined(); + expect(findToolEvent('executing')).toBeDefined(); + const successEvent = findToolEvent('success'); + expect(successEvent).toBeDefined(); expect(successEvent.status.message?.parts).toMatchObject([ { data: { - status: 'success', request: { callId: 'test-call-id-no-approval' }, }, }, ]); - // Status update: working (before sending tool result to LLM) - const workingEvent3 = events[7].result as TaskStatusUpdateEvent; - expect(workingEvent3.kind).toBe('status-update'); - expect(workingEvent3.status.state).toBe('working'); - // Status update: text-content (final LLM response) - const textContentEvent = events[8].result as TaskStatusUpdateEvent; - expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'text-content', - }); + const textContentEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'text-content', + )?.result as TaskStatusUpdateEvent; + expect(textContentEvent).toBeDefined(); expect(textContentEvent.status.message?.parts).toMatchObject([ { text: 'Tool executed successfully.' }, ]); assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(10); }); it('should bypass tool approval in YOLO mode', async () => { @@ -724,83 +707,44 @@ describe('E2E Tests', () => { const events = streamToSSEEvents(res.text); assertTaskCreationAndWorkingStatus(events); - // Status update: working - const workingEvent2 = events[2].result as TaskStatusUpdateEvent; - expect(workingEvent2.kind).toBe('status-update'); - expect(workingEvent2.status.state).toBe('working'); - - // Status update: tool-call-update (validating) - const validatingEvent = events[3].result as TaskStatusUpdateEvent; - expect(validatingEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(validatingEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'validating', - request: { callId: 'test-call-id-yolo' }, - }, - }, - ]); - - // Status update: tool-call-update (scheduled) - const awaitingEvent = events[4].result as TaskStatusUpdateEvent; - expect(awaitingEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(awaitingEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'scheduled', - request: { callId: 'test-call-id-yolo' }, - }, - }, - ]); - - // Status update: tool-call-update (executing) - const executingEvent = events[5].result as TaskStatusUpdateEvent; - expect(executingEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); - expect(executingEvent.status.message?.parts).toMatchObject([ - { - data: { - status: 'executing', - request: { callId: 'test-call-id-yolo' }, - }, - }, - ]); - - // Status update: tool-call-update (success) - const successEvent = events[6].result as TaskStatusUpdateEvent; - expect(successEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-update', - }); + const findToolEvent = (status: string) => + events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'tool-call-update' && + ( + (e.result as TaskStatusUpdateEvent).status.message?.parts[0] as { + data?: ToolCallData; + } + )?.data?.status === status, + )?.result as TaskStatusUpdateEvent; + + // Status update: tool-call-update transitions + expect(findToolEvent('validating')).toBeDefined(); + expect(findToolEvent('scheduled')).toBeDefined(); + expect(findToolEvent('executing')).toBeDefined(); + const successEvent = findToolEvent('success'); + expect(successEvent).toBeDefined(); expect(successEvent.status.message?.parts).toMatchObject([ { data: { - status: 'success', request: { callId: 'test-call-id-yolo' }, }, }, ]); - // Status update: working (before sending tool result to LLM) - const workingEvent3 = events[7].result as TaskStatusUpdateEvent; - expect(workingEvent3.kind).toBe('status-update'); - expect(workingEvent3.status.state).toBe('working'); - // Status update: text-content (final LLM response) - const textContentEvent = events[8].result as TaskStatusUpdateEvent; - expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({ - kind: 'text-content', - }); + const textContentEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'text-content', + )?.result as TaskStatusUpdateEvent; + expect(textContentEvent).toBeDefined(); expect(textContentEvent.status.message?.parts).toMatchObject([ { text: 'Tool executed successfully.' }, ]); assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(10); }); it('should include traceId in status updates when available', async () => { @@ -821,13 +765,22 @@ describe('E2E Tests', () => { const events = streamToSSEEvents(res.text); - // The first two events are task-creation and working status - const textContentEvent = events[2].result as TaskStatusUpdateEvent; - expect(textContentEvent.kind).toBe('status-update'); + // The first two events are task-creation and working status. + // We look for the text-content and thought events via their metadata/message parts. + const textContentEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as { kind?: string })?.kind === + 'text-content', + )?.result as TaskStatusUpdateEvent; + expect(textContentEvent).toBeDefined(); expect(textContentEvent.metadata?.['traceId']).toBe(traceId); - const thoughtEvent = events[3].result as TaskStatusUpdateEvent; - expect(thoughtEvent.kind).toBe('status-update'); + const thoughtEvent = events.find( + (e) => + (e.result?.metadata?.['coderAgent'] as CoderAgentMetadata)?.kind === + 'thought', + )?.result as TaskStatusUpdateEvent; + expect(thoughtEvent).toBeDefined(); expect(thoughtEvent.metadata?.['traceId']).toBe(traceId); }); diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 1ffaa61cc7b..130f4269105 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -220,9 +220,11 @@ describe('useReactToolScheduler', () => { | undefined; const advanceAndSettle = async () => { - await act(async () => { - await vi.advanceTimersByTimeAsync(0); - }); + for (let i = 0; i < 7; i++) { + await act(async () => { + await vi.advanceTimersByTimeAsync(0); + }); + } }; const scheduleAndWaitForExecution = async ( @@ -237,8 +239,6 @@ describe('useReactToolScheduler', () => { }); await advanceAndSettle(); - await advanceAndSettle(); - await advanceAndSettle(); }; beforeEach(() => { @@ -350,6 +350,12 @@ describe('useReactToolScheduler', () => { schedule(newRequest, new AbortController().signal); }); + // Wait for the async schedule operation to update state + for (let i = 0; i < 50; i++) { + if (result.current[0].length === 1) break; + await advanceAndSettle(); + } + // After scheduling, the old call should be gone, // and the new one should be in the display in its initial state. expect(result.current[0].length).toBe(1); @@ -481,6 +487,12 @@ describe('useReactToolScheduler', () => { await scheduleAndWaitForExecution(result.current[1], request); + // Poll for completion + for (let i = 0; i < 50; i++) { + if (completedToolCalls.length === 1) break; + await advanceAndSettle(); + } + expect(completedToolCalls).toHaveLength(1); expect(completedToolCalls[0].status).toBe('error'); expect(completedToolCalls[0].request).toBe(request); diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 22ef939a624..04d4d2eb9fe 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -750,6 +750,17 @@ describe('CoreToolScheduler with payload', () => { ); } + // After internal update, the tool should be awaiting approval again with the NEW content. + const updatedAwaitingCall = (await waitForStatus( + onToolCallsUpdate, + 'awaiting_approval', + )) as WaitingToolCall; + + // Now confirm for real to execute. + await updatedAwaitingCall.confirmationDetails.onConfirm( + ToolConfirmationOutcome.ProceedOnce, + ); + // Wait for the tool execution to complete await vi.waitFor(() => { expect(onAllToolCallsComplete).toHaveBeenCalled(); @@ -979,7 +990,11 @@ describe('CoreToolScheduler YOLO mode', () => { .map((call) => (call[0][0] as ToolCall)?.status) .filter(Boolean); expect(statusUpdates).not.toContain('awaiting_approval'); - expect(statusUpdates).toEqual([ + // Expect the sequence of states, ignoring duplicates + const uniqueStatusUpdates = statusUpdates.filter( + (status, index, self) => index === 0 || status !== self[index - 1], + ); + expect(uniqueStatusUpdates).toEqual([ 'validating', 'scheduled', 'executing', @@ -1196,7 +1211,11 @@ describe('CoreToolScheduler request queueing', () => { .map((call) => (call[0][0] as ToolCall)?.status) .filter(Boolean); expect(statusUpdates).not.toContain('awaiting_approval'); - expect(statusUpdates).toEqual([ + // Expect the sequence of states, ignoring duplicates + const uniqueStatusUpdates = statusUpdates.filter( + (status, index, self) => index === 0 || status !== self[index - 1], + ); + expect(uniqueStatusUpdates).toEqual([ 'validating', 'scheduled', 'executing', @@ -1799,8 +1818,9 @@ describe('CoreToolScheduler Sequential Execution', () => { abortController.signal, ); - const toolCall = (scheduler as unknown as { toolCalls: ToolCall[] }) - .toolCalls[0] as WaitingToolCall; + const toolCall = ( + scheduler as unknown as { state: { getSnapshot: () => ToolCall[] } } + ).state.getSnapshot()[0] as WaitingToolCall; expect(toolCall.status).toBe('awaiting_approval'); const confirmationSignal = new AbortController().signal; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 11200742483..2af92ed13fe 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -5,7 +5,6 @@ */ import { - type ToolResultDisplay, type AnyDeclarativeTool, type AnyToolInvocation, type ToolCallConfirmationDetails, @@ -19,12 +18,7 @@ import { logToolCall } from '../telemetry/loggers.js'; import { ToolErrorType } from '../tools/tool-error.js'; import { ToolCallEvent } from '../telemetry/types.js'; import { runInDevTraceSpan } from '../telemetry/trace.js'; -import type { ModifyContext } from '../tools/modifiable-tool.js'; -import { - isModifiableDeclarativeTool, - modifyWithEditor, -} from '../tools/modifiable-tool.js'; -import * as Diff from 'diff'; +import { ToolModificationHandler } from '../scheduler/tool-modifier.js'; import { getToolSuggestion } from '../utils/tool-utils.js'; import type { ToolConfirmationRequest } from '../confirmation-bus/types.js'; import { MessageBusType } from '../confirmation-bus/types.js'; @@ -49,6 +43,7 @@ import { type ToolCallResponseInfo, } from '../scheduler/types.js'; import { ToolExecutor } from '../scheduler/tool-executor.js'; +import { SchedulerStateManager } from '../scheduler/state-manager.js'; export type { ToolCall, @@ -106,7 +101,6 @@ export class CoreToolScheduler { (request: ToolConfirmationRequest) => void >(); - private toolCalls: ToolCall[] = []; private outputUpdateHandler?: OutputUpdateHandler; private onAllToolCallsComplete?: AllToolCallsCompleteHandler; private onToolCallsUpdate?: ToolCallsUpdateHandler; @@ -121,9 +115,9 @@ export class CoreToolScheduler { resolve: () => void; reject: (reason?: Error) => void; }> = []; - private toolCallQueue: ToolCall[] = []; - private completedToolCallsForBatch: CompletedToolCall[] = []; private toolExecutor: ToolExecutor; + private toolModifier: ToolModificationHandler; + private state: SchedulerStateManager; constructor(options: CoreToolSchedulerOptions) { this.config = options.config; @@ -132,6 +126,8 @@ export class CoreToolScheduler { this.onToolCallsUpdate = options.onToolCallsUpdate; this.getPreferredEditor = options.getPreferredEditor; this.toolExecutor = new ToolExecutor(this.config); + this.toolModifier = new ToolModificationHandler(); + this.state = new SchedulerStateManager(this.onToolCallsUpdate); // Subscribe to message bus for ASK_USER policy decisions // Use a static WeakMap to ensure we only subscribe ONCE per MessageBus instance @@ -164,222 +160,13 @@ export class CoreToolScheduler { } } - private setStatusInternal( - targetCallId: string, - status: 'success', - signal: AbortSignal, - response: ToolCallResponseInfo, - ): void; - private setStatusInternal( - targetCallId: string, - status: 'awaiting_approval', - signal: AbortSignal, - confirmationDetails: ToolCallConfirmationDetails, - ): void; - private setStatusInternal( - targetCallId: string, - status: 'error', - signal: AbortSignal, - response: ToolCallResponseInfo, - ): void; - private setStatusInternal( - targetCallId: string, - status: 'cancelled', - signal: AbortSignal, - reason: string, - ): void; - private setStatusInternal( - targetCallId: string, - status: 'executing' | 'scheduled' | 'validating', - signal: AbortSignal, - ): void; - private setStatusInternal( - targetCallId: string, - newStatus: Status, - signal: AbortSignal, - auxiliaryData?: unknown, - ): void { - this.toolCalls = this.toolCalls.map((currentCall) => { - if ( - currentCall.request.callId !== targetCallId || - currentCall.status === 'success' || - currentCall.status === 'error' || - currentCall.status === 'cancelled' - ) { - return currentCall; - } - - // currentCall is a non-terminal state here and should have startTime and tool. - const existingStartTime = currentCall.startTime; - const toolInstance = currentCall.tool; - const invocation = currentCall.invocation; - - const outcome = currentCall.outcome; - - switch (newStatus) { - case 'success': { - const durationMs = existingStartTime - ? Date.now() - existingStartTime - : undefined; - return { - request: currentCall.request, - tool: toolInstance, - invocation, - status: 'success', - response: auxiliaryData as ToolCallResponseInfo, - durationMs, - outcome, - } as SuccessfulToolCall; - } - case 'error': { - const durationMs = existingStartTime - ? Date.now() - existingStartTime - : undefined; - return { - request: currentCall.request, - status: 'error', - tool: toolInstance, - response: auxiliaryData as ToolCallResponseInfo, - durationMs, - outcome, - } as ErroredToolCall; - } - case 'awaiting_approval': - return { - request: currentCall.request, - tool: toolInstance, - status: 'awaiting_approval', - confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, - startTime: existingStartTime, - outcome, - invocation, - } as WaitingToolCall; - case 'scheduled': - return { - request: currentCall.request, - tool: toolInstance, - status: 'scheduled', - startTime: existingStartTime, - outcome, - invocation, - } as ScheduledToolCall; - case 'cancelled': { - const durationMs = existingStartTime - ? Date.now() - existingStartTime - : undefined; - - // Preserve diff for cancelled edit operations - let resultDisplay: ToolResultDisplay | undefined = undefined; - if (currentCall.status === 'awaiting_approval') { - const waitingCall = currentCall; - if (waitingCall.confirmationDetails.type === 'edit') { - resultDisplay = { - fileDiff: waitingCall.confirmationDetails.fileDiff, - fileName: waitingCall.confirmationDetails.fileName, - originalContent: - waitingCall.confirmationDetails.originalContent, - newContent: waitingCall.confirmationDetails.newContent, - filePath: waitingCall.confirmationDetails.filePath, - }; - } - } - - const errorMessage = `[Operation Cancelled] Reason: ${auxiliaryData}`; - return { - request: currentCall.request, - tool: toolInstance, - invocation, - status: 'cancelled', - response: { - callId: currentCall.request.callId, - responseParts: [ - { - functionResponse: { - id: currentCall.request.callId, - name: currentCall.request.name, - response: { - error: errorMessage, - }, - }, - }, - ], - resultDisplay, - error: undefined, - errorType: undefined, - contentLength: errorMessage.length, - }, - durationMs, - outcome, - } as CancelledToolCall; - } - case 'validating': - return { - request: currentCall.request, - tool: toolInstance, - status: 'validating', - startTime: existingStartTime, - outcome, - invocation, - } as ValidatingToolCall; - case 'executing': - return { - request: currentCall.request, - tool: toolInstance, - status: 'executing', - startTime: existingStartTime, - outcome, - invocation, - } as ExecutingToolCall; - default: { - const exhaustiveCheck: never = newStatus; - return exhaustiveCheck; - } - } - }); - this.notifyToolCallsUpdate(); - } - - private setArgsInternal(targetCallId: string, args: unknown): void { - this.toolCalls = this.toolCalls.map((call) => { - // We should never be asked to set args on an ErroredToolCall, but - // we guard for the case anyways. - if (call.request.callId !== targetCallId || call.status === 'error') { - return call; - } - - const invocationOrError = this.buildInvocation( - call.tool, - args as Record, - ); - if (invocationOrError instanceof Error) { - const response = createErrorResponse( - call.request, - invocationOrError, - ToolErrorType.INVALID_TOOL_PARAMS, - ); - return { - request: { ...call.request, args: args as Record }, - status: 'error', - tool: call.tool, - response, - } as ErroredToolCall; - } - - return { - ...call, - request: { ...call.request, args: args as Record }, - invocation: invocationOrError, - }; - }); - } - private isRunning(): boolean { + const firstActive = this.state.getFirstActiveCall(); return ( this.isFinalizingToolCalls || - this.toolCalls.some( - (call) => - call.status === 'executing' || call.status === 'awaiting_approval', - ) + (firstActive !== undefined && + (firstActive.status === 'executing' || + firstActive.status === 'awaiting_approval')) ); } @@ -445,8 +232,8 @@ export class CoreToolScheduler { } this.isCancelling = true; // Cancel the currently active tool call, if there is one. - if (this.toolCalls.length > 0) { - const activeCall = this.toolCalls[0]; + const activeCall = this.state.getFirstActiveCall(); + if (activeCall) { // Only cancel if it's in a cancellable state. if ( activeCall.status === 'awaiting_approval' || @@ -454,10 +241,9 @@ export class CoreToolScheduler { activeCall.status === 'scheduled' || activeCall.status === 'validating' ) { - this.setStatusInternal( + this.state.updateStatus( activeCall.request.callId, 'cancelled', - signal, 'User cancelled the operation.', ); } @@ -483,7 +269,7 @@ export class CoreToolScheduler { ); } const requestsToProcess = Array.isArray(request) ? request : [request]; - this.completedToolCallsForBatch = []; + this.state.clearBatch(); const newToolCalls: ToolCall[] = requestsToProcess.map( (reqInfo): ToolCall => { @@ -504,8 +290,9 @@ export class CoreToolScheduler { new Error(errorMessage), ToolErrorType.TOOL_NOT_REGISTERED, ), + startTime: Date.now(), durationMs: 0, - }; + } as ErroredToolCall; } const invocationOrError = this.buildInvocation( @@ -522,8 +309,9 @@ export class CoreToolScheduler { invocationOrError, ToolErrorType.INVALID_TOOL_PARAMS, ), + startTime: Date.now(), durationMs: 0, - }; + } as ErroredToolCall; } return { @@ -536,7 +324,7 @@ export class CoreToolScheduler { }, ); - this.toolCallQueue.push(...newToolCalls); + this.state.enqueue(newToolCalls); await this._processNextInQueue(signal); } finally { this.isScheduling = false; @@ -545,7 +333,7 @@ export class CoreToolScheduler { private async _processNextInQueue(signal: AbortSignal): Promise { // If there's already a tool being processed, or the queue is empty, stop. - if (this.toolCalls.length > 0 || this.toolCallQueue.length === 0) { + if (this.state.hasActiveCalls() || this.state.getQueueLength() === 0) { return; } @@ -557,16 +345,17 @@ export class CoreToolScheduler { return; } - const toolCall = this.toolCallQueue.shift()!; - - // This is now the single active tool call. - this.toolCalls = [toolCall]; - this.notifyToolCallsUpdate(); + const toolCall = this.state.dequeue()!; // Handle tools that were already errored during creation. if (toolCall.status === 'error') { // An error during validation means this "active" tool is already complete. // We need to check for batch completion to either finish or process the next in queue. + this.state.updateStatus( + toolCall.request.callId, + 'error', + toolCall.response, + ); await this.checkAndNotifyCompletion(signal); return; } @@ -577,10 +366,9 @@ export class CoreToolScheduler { try { if (signal.aborted) { - this.setStatusInternal( + this.state.updateStatus( reqInfo.callId, 'cancelled', - signal, 'Tool call cancelled by user.', ); // The completion check will handle the cascade. @@ -600,10 +388,9 @@ export class CoreToolScheduler { if (decision === PolicyDecision.DENY) { const errorMessage = `Tool execution denied by policy.`; - this.setStatusInternal( + this.state.updateStatus( reqInfo.callId, 'error', - signal, createErrorResponse( reqInfo, new Error(errorMessage), @@ -615,11 +402,11 @@ export class CoreToolScheduler { } if (decision === PolicyDecision.ALLOW) { - this.setToolCallOutcome( + this.state.setOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, ); - this.setStatusInternal(reqInfo.callId, 'scheduled', signal); + this.state.updateStatus(reqInfo.callId, 'scheduled'); } else { // PolicyDecision.ASK_USER @@ -628,11 +415,11 @@ export class CoreToolScheduler { await invocation.shouldConfirmExecute(signal); if (!confirmationDetails) { - this.setToolCallOutcome( + this.state.setOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, ); - this.setStatusInternal(reqInfo.callId, 'scheduled', signal); + this.state.updateStatus(reqInfo.callId, 'scheduled'); } else { if (!this.config.isInteractive()) { throw new Error( @@ -691,28 +478,26 @@ export class CoreToolScheduler { payload, ), }; - this.setStatusInternal( + this.state.updateStatus( reqInfo.callId, 'awaiting_approval', - signal, wrappedConfirmationDetails, ); } } } catch (error) { if (signal.aborted) { - this.setStatusInternal( + this.state.updateStatus( reqInfo.callId, 'cancelled', - signal, 'Tool call cancelled by user.', ); await this.checkAndNotifyCompletion(signal); + return; } else { - this.setStatusInternal( + this.state.updateStatus( reqInfo.callId, 'error', - signal, createErrorResponse( reqInfo, error instanceof Error ? error : new Error(String(error)), @@ -733,15 +518,17 @@ export class CoreToolScheduler { signal: AbortSignal, payload?: ToolConfirmationPayload, ): Promise { - const toolCall = this.toolCalls.find( - (c) => c.request.callId === callId && c.status === 'awaiting_approval', - ); + const toolCall = this.state + .getSnapshot() + .find( + (c) => c.request.callId === callId && c.status === 'awaiting_approval', + ); if (toolCall && toolCall.status === 'awaiting_approval') { await originalOnConfirm(outcome); } - this.setToolCallOutcome(callId, outcome); + this.state.setOutcome(callId, outcome); if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) { // Instead of just cancelling one tool, trigger the full cancel cascade. @@ -749,130 +536,107 @@ export class CoreToolScheduler { return; // `cancelAll` calls `checkAndNotifyCompletion`, so we can exit here. } else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { const waitingToolCall = toolCall as WaitingToolCall; - if (isModifiableDeclarativeTool(waitingToolCall.tool)) { - const modifyContext = waitingToolCall.tool.getModifyContext(signal); - const editorType = this.getPreferredEditor(); - if (!editorType) { - return; - } - this.setStatusInternal(callId, 'awaiting_approval', signal, { - ...waitingToolCall.confirmationDetails, - isModifying: true, - } as ToolCallConfirmationDetails); + const editorType = this.getPreferredEditor(); + if (!editorType) { + return; + } - const contentOverrides = - waitingToolCall.confirmationDetails.type === 'edit' - ? { - currentContent: - waitingToolCall.confirmationDetails.originalContent, - proposedContent: waitingToolCall.confirmationDetails.newContent, - } - : undefined; + this.state.updateStatus(callId, 'awaiting_approval', { + ...waitingToolCall.confirmationDetails, + isModifying: true, + } as ToolCallConfirmationDetails); - const { updatedParams, updatedDiff } = await modifyWithEditor< - typeof waitingToolCall.request.args - >( - waitingToolCall.request.args, - modifyContext as ModifyContext, - editorType, - signal, - contentOverrides, + const result = await this.toolModifier.handleModifyWithEditor( + waitingToolCall, + editorType, + signal, + ); + + // Restore status (isModifying: false) and update diff if result exists + if (result) { + const invocationOrError = this.buildInvocation( + waitingToolCall.tool, + result.updatedParams, ); - this.setArgsInternal(callId, updatedParams); - this.setStatusInternal(callId, 'awaiting_approval', signal, { + if (!(invocationOrError instanceof Error)) { + this.state.updateArgs( + callId, + result.updatedParams, + invocationOrError, + ); + } + this.state.updateStatus(callId, 'awaiting_approval', { + ...waitingToolCall.confirmationDetails, + fileDiff: result.updatedDiff, + isModifying: false, + } as ToolCallConfirmationDetails); + } else { + this.state.updateStatus(callId, 'awaiting_approval', { ...waitingToolCall.confirmationDetails, - fileDiff: updatedDiff, isModifying: false, } as ToolCallConfirmationDetails); } } else { - // If the client provided new content, apply it before scheduling. + // If the client provided new content, apply it and wait for + // re-confirmation. if (payload?.newContent && toolCall) { - await this._applyInlineModify( + const result = await this.toolModifier.applyInlineModify( toolCall as WaitingToolCall, payload, signal, ); + if (result) { + const invocationOrError = this.buildInvocation( + (toolCall as WaitingToolCall).tool, + result.updatedParams, + ); + if (!(invocationOrError instanceof Error)) { + this.state.updateArgs( + callId, + result.updatedParams, + invocationOrError, + ); + } + this.state.updateStatus(callId, 'awaiting_approval', { + ...(toolCall as WaitingToolCall).confirmationDetails, + fileDiff: result.updatedDiff, + } as ToolCallConfirmationDetails); + // After an inline modification, wait for another user confirmation. + return; + } } - this.setStatusInternal(callId, 'scheduled', signal); + this.state.updateStatus(callId, 'scheduled'); } await this.attemptExecutionOfScheduledCalls(signal); } - /** - * Applies user-provided content changes to a tool call that is awaiting confirmation. - * This method updates the tool's arguments and refreshes the confirmation prompt with a new diff - * before the tool is scheduled for execution. - * @private - */ - private async _applyInlineModify( - toolCall: WaitingToolCall, - payload: ToolConfirmationPayload, - signal: AbortSignal, - ): Promise { - if ( - toolCall.confirmationDetails.type !== 'edit' || - !isModifiableDeclarativeTool(toolCall.tool) - ) { - return; - } - - const modifyContext = toolCall.tool.getModifyContext(signal); - const currentContent = await modifyContext.getCurrentContent( - toolCall.request.args, - ); - - const updatedParams = modifyContext.createUpdatedParams( - currentContent, - payload.newContent, - toolCall.request.args, - ); - const updatedDiff = Diff.createPatch( - modifyContext.getFilePath(toolCall.request.args), - currentContent, - payload.newContent, - 'Current', - 'Proposed', - ); - - this.setArgsInternal(toolCall.request.callId, updatedParams); - this.setStatusInternal( - toolCall.request.callId, - 'awaiting_approval', - signal, - { - ...toolCall.confirmationDetails, - fileDiff: updatedDiff, - }, - ); - } - private async attemptExecutionOfScheduledCalls( signal: AbortSignal, ): Promise { - const allCallsFinalOrScheduled = this.toolCalls.every( + const allCallsFinalOrScheduled = this.state.getSnapshot().every( (call) => call.status === 'scheduled' || call.status === 'cancelled' || call.status === 'success' || - call.status === 'error', + call.status === 'error' || + call.status === 'validating', // validating ones are in queue ); if (allCallsFinalOrScheduled) { - const callsToExecute = this.toolCalls.filter( - (call) => call.status === 'scheduled', - ); + const callsToExecute = this.state + .getSnapshot() + .filter((call) => call.status === 'scheduled'); for (const toolCall of callsToExecute) { if (toolCall.status !== 'scheduled') continue; - this.setStatusInternal(toolCall.request.callId, 'executing', signal); - const executingCall = this.toolCalls.find( - (c) => c.request.callId === toolCall.request.callId, - ); + this.state.updateStatus(toolCall.request.callId, 'executing'); + const executingCall = this.state + .getSnapshot() + .find((c) => c.request.callId === toolCall.request.callId); - if (!executingCall) { + if (!executingCall || executingCall.status !== 'executing') { // Should not happen, but safe guard continue; } @@ -884,29 +648,28 @@ export class CoreToolScheduler { if (this.outputUpdateHandler) { this.outputUpdateHandler(callId, output); } - this.toolCalls = this.toolCalls.map((tc) => - tc.request.callId === callId && tc.status === 'executing' - ? { ...tc, liveOutput: output } - : tc, - ); - this.notifyToolCallsUpdate(); + // Update live output in state manager + this.state.updateStatus(callId, 'executing', { + ...executingCall, + liveOutput: output, + }); }, onUpdateToolCall: (updatedCall) => { - this.toolCalls = this.toolCalls.map((tc) => - tc.request.callId === updatedCall.request.callId - ? updatedCall - : tc, + // This is a bit tricky since updateStatus handles transitions. + // For general updates, we might need a more direct way or just use updateStatus with current status. + this.state.updateStatus( + updatedCall.request.callId, + updatedCall.status, + updatedCall, ); - this.notifyToolCallsUpdate(); }, }); - this.toolCalls = this.toolCalls.map((tc) => - tc.request.callId === completedCall.request.callId - ? completedCall - : tc, + this.state.updateStatus( + completedCall.request.callId, + completedCall.status, + completedCall.response, ); - this.notifyToolCallsUpdate(); await this.checkAndNotifyCompletion(signal); } @@ -915,13 +678,13 @@ export class CoreToolScheduler { private async checkAndNotifyCompletion(signal: AbortSignal): Promise { // This method is now only concerned with the single active tool call. - if (this.toolCalls.length === 0) { + if (!this.state.hasActiveCalls()) { // It's possible to be called when a batch is cancelled before any tool has started. - if (signal.aborted && this.toolCallQueue.length > 0) { + if (signal.aborted && this.state.getQueueLength() > 0) { this._cancelAllQueuedCalls(); } } else { - const activeCall = this.toolCalls[0]; + const activeCall = this.state.getFirstActiveCall()!; const isTerminal = activeCall.status === 'success' || activeCall.status === 'error' || @@ -933,37 +696,38 @@ export class CoreToolScheduler { return; } - // The active tool is finished. Move it to the completed batch. - const completedCall = activeCall as CompletedToolCall; - this.completedToolCallsForBatch.push(completedCall); - logToolCall(this.config, new ToolCallEvent(completedCall)); - - // Clear the active tool slot. This is crucial for the sequential processing. - this.toolCalls = []; + // The state manager handles moving terminal calls to the completed batch + // and removing them from the active map. + logToolCall( + this.config, + new ToolCallEvent(activeCall as CompletedToolCall), + ); + this.state.finalizeCall(activeCall.request.callId); } // Now, check if the entire batch is complete. // The batch is complete if the queue is empty or the operation was cancelled. - if (this.toolCallQueue.length === 0 || signal.aborted) { + if (this.state.getQueueLength() === 0 || signal.aborted) { if (signal.aborted) { this._cancelAllQueuedCalls(); } + const completedBatch = this.state.getCompletedBatch(); + // If there's nothing to report and we weren't cancelled, we can stop. // But if we were cancelled, we must proceed to potentially start the next queued request. - if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) { + if (completedBatch.length === 0 && !signal.aborted) { return; } if (this.onAllToolCallsComplete) { this.isFinalizingToolCalls = true; // Use the batch array, not the (now empty) active array. - await this.onAllToolCallsComplete(this.completedToolCallsForBatch); - this.completedToolCallsForBatch = []; // Clear after reporting. + await this.onAllToolCallsComplete(completedBatch); + this.state.clearBatch(); // Clear after reporting. this.isFinalizingToolCalls = false; } this.isCancelling = false; - this.notifyToolCallsUpdate(); // After completion of the entire batch, process the next item in the main request queue. if (this.requestQueue.length > 0) { @@ -979,65 +743,6 @@ export class CoreToolScheduler { } private _cancelAllQueuedCalls(): void { - while (this.toolCallQueue.length > 0) { - const queuedCall = this.toolCallQueue.shift()!; - // Don't cancel tools that already errored during validation. - if (queuedCall.status === 'error') { - this.completedToolCallsForBatch.push(queuedCall); - continue; - } - const durationMs = - 'startTime' in queuedCall && queuedCall.startTime - ? Date.now() - queuedCall.startTime - : undefined; - const errorMessage = - '[Operation Cancelled] User cancelled the operation.'; - this.completedToolCallsForBatch.push({ - request: queuedCall.request, - tool: queuedCall.tool, - invocation: queuedCall.invocation, - status: 'cancelled', - response: { - callId: queuedCall.request.callId, - responseParts: [ - { - functionResponse: { - id: queuedCall.request.callId, - name: queuedCall.request.name, - response: { - error: errorMessage, - }, - }, - }, - ], - resultDisplay: undefined, - error: undefined, - errorType: undefined, - contentLength: errorMessage.length, - }, - durationMs, - outcome: ToolConfirmationOutcome.Cancel, - }); - } - } - - private notifyToolCallsUpdate(): void { - if (this.onToolCallsUpdate) { - this.onToolCallsUpdate([ - ...this.completedToolCallsForBatch, - ...this.toolCalls, - ...this.toolCallQueue, - ]); - } - } - - private setToolCallOutcome(callId: string, outcome: ToolConfirmationOutcome) { - this.toolCalls = this.toolCalls.map((call) => { - if (call.request.callId !== callId) return call; - return { - ...call, - outcome, - }; - }); + this.state.cancelAllQueued('User cancelled the operation.'); } } diff --git a/packages/core/src/scheduler/state-manager.test.ts b/packages/core/src/scheduler/state-manager.test.ts new file mode 100644 index 00000000000..40ac30b60a0 --- /dev/null +++ b/packages/core/src/scheduler/state-manager.test.ts @@ -0,0 +1,403 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { SchedulerStateManager } from './state-manager.js'; +import { + type ValidatingToolCall, + type WaitingToolCall, + type SuccessfulToolCall, + type ErroredToolCall, + type CancelledToolCall, + type ExecutingToolCall, + type ToolCallRequestInfo, + type ToolCallResponseInfo, + type ToolCallsUpdateHandler, +} from './types.js'; +import { + ToolConfirmationOutcome, + type AnyDeclarativeTool, + type AnyToolInvocation, +} from '../tools/tools.js'; + +describe('SchedulerStateManager', () => { + const mockRequest: ToolCallRequestInfo = { + callId: 'call-1', + name: 'test-tool', + args: { foo: 'bar' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + const mockTool = { + name: 'test-tool', + displayName: 'Test Tool', + } as AnyDeclarativeTool; + + const mockInvocation = { + shouldConfirmExecute: vi.fn(), + } as unknown as AnyToolInvocation; + + const createValidatingCall = (id = 'call-1'): ValidatingToolCall => ({ + status: 'validating', + request: { ...mockRequest, callId: id }, + tool: mockTool, + invocation: mockInvocation, + startTime: Date.now(), + }); + + let stateManager: SchedulerStateManager; + let onUpdate: ToolCallsUpdateHandler; + + beforeEach(() => { + onUpdate = vi.fn() as unknown as ToolCallsUpdateHandler; + stateManager = new SchedulerStateManager(onUpdate); + }); + + describe('Initialization', () => { + it('should start with empty state', () => { + expect(stateManager.hasActiveCalls()).toBe(false); + expect(stateManager.getActiveCallCount()).toBe(0); + expect(stateManager.getQueueLength()).toBe(0); + expect(stateManager.getSnapshot()).toEqual([]); + }); + }); + + describe('Queue Management', () => { + it('should enqueue calls and notify', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + + expect(stateManager.getQueueLength()).toBe(1); + expect(onUpdate).toHaveBeenCalledWith([call]); + }); + + it('should dequeue calls and notify', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + + const dequeued = stateManager.dequeue(); + + expect(dequeued).toEqual(call); + expect(stateManager.getQueueLength()).toBe(0); + expect(stateManager.getActiveCallCount()).toBe(1); + expect(onUpdate).toHaveBeenCalled(); + }); + + it('should return undefined when dequeueing from empty queue', () => { + const dequeued = stateManager.dequeue(); + expect(dequeued).toBeUndefined(); + }); + }); + + describe('Status Transitions', () => { + it('should transition validating to scheduled', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + stateManager.updateStatus(call.request.callId, 'scheduled'); + + const snapshot = stateManager.getSnapshot(); + expect(snapshot[0].status).toBe('scheduled'); + expect(snapshot[0].request.callId).toBe(call.request.callId); + }); + + it('should transition scheduled to executing', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(call.request.callId, 'scheduled'); + + stateManager.updateStatus(call.request.callId, 'executing'); + + expect(stateManager.getFirstActiveCall()?.status).toBe('executing'); + }); + + it('should transition to success and move to completed batch', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + const response: ToolCallResponseInfo = { + callId: call.request.callId, + responseParts: [], + resultDisplay: 'Success', + error: undefined, + errorType: undefined, + }; + + stateManager.updateStatus(call.request.callId, 'success', response); + stateManager.finalizeCall(call.request.callId); + + expect(stateManager.hasActiveCalls()).toBe(false); + expect(stateManager.getCompletedBatch()).toHaveLength(1); + const completed = + stateManager.getCompletedBatch()[0] as SuccessfulToolCall; + expect(completed.status).toBe('success'); + expect(completed.response).toEqual(response); + expect(completed.durationMs).toBeDefined(); + }); + + it('should transition to error and move to completed batch', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + const response: ToolCallResponseInfo = { + callId: call.request.callId, + responseParts: [], + resultDisplay: 'Error', + error: new Error('Failed'), + errorType: undefined, + }; + + stateManager.updateStatus(call.request.callId, 'error', response); + stateManager.finalizeCall(call.request.callId); + + expect(stateManager.hasActiveCalls()).toBe(false); + expect(stateManager.getCompletedBatch()).toHaveLength(1); + const completed = stateManager.getCompletedBatch()[0] as ErroredToolCall; + expect(completed.status).toBe('error'); + expect(completed.response).toEqual(response); + }); + + it('should transition to awaiting_approval with details', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + const details = { type: 'info', title: 'Confirm', prompt: 'Proceed?' }; + + stateManager.updateStatus( + call.request.callId, + 'awaiting_approval', + details, + ); + + const active = stateManager.getFirstActiveCall() as WaitingToolCall; + expect(active.status).toBe('awaiting_approval'); + expect(active.confirmationDetails).toEqual(details); + }); + + it('should preserve diff when cancelling an edit tool call', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + const details = { + type: 'edit', + title: 'Edit', + fileName: 'test.txt', + fileDiff: 'diff', + originalContent: 'old', + newContent: 'new', + }; + + stateManager.updateStatus( + call.request.callId, + 'awaiting_approval', + details, + ); + stateManager.updateStatus( + call.request.callId, + 'cancelled', + 'User said no', + ); + stateManager.finalizeCall(call.request.callId); + + const completed = + stateManager.getCompletedBatch()[0] as CancelledToolCall; + expect(completed.status).toBe('cancelled'); + expect(completed.response.resultDisplay).toEqual({ + fileDiff: 'diff', + fileName: 'test.txt', + originalContent: 'old', + newContent: 'new', + }); + }); + + it('should ignore status updates for non-existent callIds', () => { + stateManager.updateStatus('unknown', 'scheduled'); + expect(onUpdate).not.toHaveBeenCalled(); + }); + + it('should ignore status updates for terminal calls', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(call.request.callId, 'success', {}); + stateManager.finalizeCall(call.request.callId); + + vi.mocked(onUpdate).mockClear(); + stateManager.updateStatus(call.request.callId, 'scheduled'); + expect(onUpdate).not.toHaveBeenCalled(); + }); + + it('should only finalize terminal calls', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + stateManager.updateStatus(call.request.callId, 'executing'); + stateManager.finalizeCall(call.request.callId); + + expect(stateManager.hasActiveCalls()).toBe(true); + expect(stateManager.getCompletedBatch()).toHaveLength(0); + + stateManager.updateStatus(call.request.callId, 'success', {}); + stateManager.finalizeCall(call.request.callId); + + expect(stateManager.hasActiveCalls()).toBe(false); + expect(stateManager.getCompletedBatch()).toHaveLength(1); + }); + + it('should merge liveOutput and pid during executing updates', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + // Start executing + stateManager.updateStatus(call.request.callId, 'executing'); + let active = stateManager.getFirstActiveCall() as ExecutingToolCall; + expect(active.status).toBe('executing'); + expect(active.liveOutput).toBeUndefined(); + + // Update with live output + stateManager.updateStatus(call.request.callId, 'executing', { + liveOutput: 'chunk 1', + }); + active = stateManager.getFirstActiveCall() as ExecutingToolCall; + expect(active.liveOutput).toBe('chunk 1'); + + // Update with pid (should preserve liveOutput) + stateManager.updateStatus(call.request.callId, 'executing', { + pid: 1234, + }); + active = stateManager.getFirstActiveCall() as ExecutingToolCall; + expect(active.liveOutput).toBe('chunk 1'); + expect(active.pid).toBe(1234); + + // Update live output again (should preserve pid) + stateManager.updateStatus(call.request.callId, 'executing', { + liveOutput: 'chunk 2', + }); + active = stateManager.getFirstActiveCall() as ExecutingToolCall; + expect(active.liveOutput).toBe('chunk 2'); + expect(active.pid).toBe(1234); + }); + }); + + describe('Argument Updates', () => { + it('should update args and invocation', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + const newArgs = { foo: 'updated' }; + const newInvocation = { ...mockInvocation } as AnyToolInvocation; + + stateManager.updateArgs(call.request.callId, newArgs, newInvocation); + + const active = stateManager.getFirstActiveCall(); + if (active && 'invocation' in active) { + expect(active.invocation).toEqual(newInvocation); + } else { + throw new Error('Active call should have invocation'); + } + }); + + it('should ignore arg updates for errored calls', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(call.request.callId, 'error', {}); + stateManager.finalizeCall(call.request.callId); + + stateManager.updateArgs( + call.request.callId, + { foo: 'new' }, + mockInvocation, + ); + + const completed = stateManager.getCompletedBatch()[0]; + expect(completed.request.args).toEqual(mockRequest.args); + }); + }); + + describe('Outcome Tracking', () => { + it('should set outcome and notify', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + + stateManager.setOutcome( + call.request.callId, + ToolConfirmationOutcome.ProceedAlways, + ); + + const active = stateManager.getFirstActiveCall(); + expect(active?.outcome).toBe(ToolConfirmationOutcome.ProceedAlways); + expect(onUpdate).toHaveBeenCalled(); + }); + }); + + describe('Batch Operations', () => { + it('should cancel all queued calls', () => { + stateManager.enqueue([ + createValidatingCall('1'), + createValidatingCall('2'), + ]); + + stateManager.cancelAllQueued('Batch cancel'); + + expect(stateManager.getQueueLength()).toBe(0); + expect(stateManager.getCompletedBatch()).toHaveLength(2); + expect( + stateManager.getCompletedBatch().every((c) => c.status === 'cancelled'), + ).toBe(true); + }); + + it('should clear batch and notify', () => { + const call = createValidatingCall(); + stateManager.enqueue([call]); + stateManager.dequeue(); + stateManager.updateStatus(call.request.callId, 'success', {}); + stateManager.finalizeCall(call.request.callId); + + stateManager.clearBatch(); + + expect(stateManager.getCompletedBatch()).toHaveLength(0); + expect(onUpdate).toHaveBeenCalledWith([]); + }); + }); + + describe('Snapshot and Ordering', () => { + it('should return snapshot in order: completed, active, queue', () => { + // 1. Completed + const call1 = createValidatingCall('1'); + stateManager.enqueue([call1]); + stateManager.dequeue(); + stateManager.updateStatus('1', 'success', {}); + stateManager.finalizeCall('1'); + + // 2. Active + const call2 = createValidatingCall('2'); + stateManager.enqueue([call2]); + stateManager.dequeue(); + + // 3. Queue + const call3 = createValidatingCall('3'); + stateManager.enqueue([call3]); + + const snapshot = stateManager.getSnapshot(); + expect(snapshot).toHaveLength(3); + expect(snapshot[0].request.callId).toBe('1'); + expect(snapshot[1].request.callId).toBe('2'); + expect(snapshot[2].request.callId).toBe('3'); + }); + }); +}); diff --git a/packages/core/src/scheduler/state-manager.ts b/packages/core/src/scheduler/state-manager.ts new file mode 100644 index 00000000000..61891ac585c --- /dev/null +++ b/packages/core/src/scheduler/state-manager.ts @@ -0,0 +1,365 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + type ToolCall, + type Status, + type WaitingToolCall, + type ToolCallsUpdateHandler, + type CompletedToolCall, + type SuccessfulToolCall, + type ErroredToolCall, + type CancelledToolCall, + type ScheduledToolCall, + type ValidatingToolCall, + type ExecutingToolCall, + type ToolCallResponseInfo, +} from './types.js'; +import { + type ToolCallConfirmationDetails, + ToolConfirmationOutcome, + type ToolResultDisplay, + type AnyToolInvocation, +} from '../tools/tools.js'; + +/** + * Manages the state of tool calls for the CoreToolScheduler. + * This class encapsulates the data structures and state transitions. + */ +export class SchedulerStateManager { + private activeCalls = new Map(); + private queue: ToolCall[] = []; + private completedBatch: CompletedToolCall[] = []; + + constructor(private onUpdate?: ToolCallsUpdateHandler) {} + + enqueue(calls: ToolCall[]): void { + this.queue.push(...calls); + this.emitUpdate(); + } + + dequeue(): ToolCall | undefined { + const next = this.queue.shift(); + if (next) { + this.activeCalls.set(next.request.callId, next); + } + return next; + } + + hasActiveCalls(): boolean { + return this.activeCalls.size > 0; + } + + getActiveCallCount(): number { + return this.activeCalls.size; + } + + getQueueLength(): number { + return this.queue.length; + } + + /** + * Returns the first active call, if any. + * Useful for the current single-active-tool implementation. + */ + getFirstActiveCall(): ToolCall | undefined { + return this.activeCalls.values().next().value; + } + + updateStatus(callId: string, status: Status, auxiliaryData?: unknown): void { + const call = this.activeCalls.get(callId); + if (!call) return; + + const updatedCall = this.transitionCall(call, status, auxiliaryData); + + this.activeCalls.set(callId, updatedCall); + + this.emitUpdate(); + } + + /** + * Moves an active tool call to the completed batch. + * This should be called by the scheduler after it has finished processing + * the terminal state (e.g., logging, etc). + */ + finalizeCall(callId: string): void { + const call = this.activeCalls.get(callId); + if (!call) { + return; + } + + if (this.isTerminal(call.status)) { + this.completedBatch.push(call as CompletedToolCall); + this.activeCalls.delete(callId); + } + } + + updateArgs( + callId: string, + newArgs: Record, + newInvocation: AnyToolInvocation, + ): void { + const call = this.activeCalls.get(callId); + if (!call || call.status === 'error') return; + + this.activeCalls.set(callId, { + ...call, + request: { ...call.request, args: newArgs }, + invocation: newInvocation, + } as ToolCall); + this.emitUpdate(); + } + + setOutcome(callId: string, outcome: ToolConfirmationOutcome): void { + const call = this.activeCalls.get(callId); + if (!call) return; + + this.activeCalls.set(callId, { + ...call, + outcome, + } as ToolCall); + this.emitUpdate(); + } + + cancelAllQueued(reason: string): void { + while (this.queue.length > 0) { + const queuedCall = this.queue.shift()!; + + // Don't cancel tools that already errored during validation. + if (queuedCall.status === 'error') { + this.completedBatch.push(queuedCall); + continue; + } + + const durationMs = + 'startTime' in queuedCall && queuedCall.startTime + ? Date.now() - queuedCall.startTime + : undefined; + + const errorMessage = `[Operation Cancelled] ${reason}`; + + this.completedBatch.push({ + request: queuedCall.request, + tool: queuedCall.tool, + invocation: queuedCall.invocation, + status: 'cancelled', + response: { + callId: queuedCall.request.callId, + responseParts: [ + { + functionResponse: { + id: queuedCall.request.callId, + name: queuedCall.request.name, + response: { + error: errorMessage, + }, + }, + }, + ], + resultDisplay: undefined, + error: undefined, + errorType: undefined, + contentLength: errorMessage.length, + }, + durationMs, + outcome: ToolConfirmationOutcome.Cancel, + } as CancelledToolCall); + } + this.emitUpdate(); + } + + getSnapshot(): ToolCall[] { + return [ + ...this.completedBatch, + ...Array.from(this.activeCalls.values()), + ...this.queue, + ]; + } + + clearBatch(): void { + if (this.completedBatch.length === 0) return; + this.completedBatch = []; + this.emitUpdate(); + } + + getCompletedBatch(): CompletedToolCall[] { + return this.completedBatch; + } + + private emitUpdate() { + if (this.onUpdate) { + this.onUpdate(this.getSnapshot()); + } + } + + private isTerminal(status: Status): boolean { + return status === 'success' || status === 'error' || status === 'cancelled'; + } + + private transitionCall( + call: ToolCall, + newStatus: Status, + auxiliaryData?: unknown, + ): ToolCall { + switch (newStatus) { + case 'success': + return this.toSuccess(call, auxiliaryData as ToolCallResponseInfo); + case 'error': + return this.toError(call, auxiliaryData as ToolCallResponseInfo); + case 'awaiting_approval': + return this.toAwaitingApproval( + call, + auxiliaryData as ToolCallConfirmationDetails, + ); + case 'scheduled': + return this.toScheduled(call); + case 'cancelled': + return this.toCancelled(call, auxiliaryData as string); + case 'validating': + return this.toValidating(call); + case 'executing': + return this.toExecuting(call, auxiliaryData); + default: { + const exhaustiveCheck: never = newStatus; + return exhaustiveCheck; + } + } + } + + private toSuccess( + call: ToolCall, + response: ToolCallResponseInfo, + ): SuccessfulToolCall { + const startTime = 'startTime' in call ? call.startTime : undefined; + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + invocation: 'invocation' in call ? call.invocation : undefined, + status: 'success', + response, + durationMs: startTime ? Date.now() - startTime : undefined, + outcome: call.outcome, + } as SuccessfulToolCall; + } + + private toError( + call: ToolCall, + response: ToolCallResponseInfo, + ): ErroredToolCall { + const startTime = 'startTime' in call ? call.startTime : undefined; + return { + request: call.request, + status: 'error', + tool: 'tool' in call ? call.tool : undefined, + response, + durationMs: startTime ? Date.now() - startTime : undefined, + outcome: call.outcome, + } as ErroredToolCall; + } + + private toAwaitingApproval( + call: ToolCall, + confirmationDetails: ToolCallConfirmationDetails, + ): WaitingToolCall { + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + status: 'awaiting_approval', + confirmationDetails, + startTime: 'startTime' in call ? call.startTime : undefined, + outcome: call.outcome, + invocation: 'invocation' in call ? call.invocation : undefined, + } as WaitingToolCall; + } + + private toScheduled(call: ToolCall): ScheduledToolCall { + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + status: 'scheduled', + startTime: 'startTime' in call ? call.startTime : undefined, + outcome: call.outcome, + invocation: 'invocation' in call ? call.invocation : undefined, + } as ScheduledToolCall; + } + + private toCancelled(call: ToolCall, reason: string): CancelledToolCall { + const startTime = 'startTime' in call ? call.startTime : undefined; + + // Preserve diff for cancelled edit operations + let resultDisplay: ToolResultDisplay | undefined = undefined; + if (call.status === 'awaiting_approval') { + const waitingCall = call; + if (waitingCall.confirmationDetails.type === 'edit') { + resultDisplay = { + fileDiff: waitingCall.confirmationDetails.fileDiff, + fileName: waitingCall.confirmationDetails.fileName, + filePath: waitingCall.confirmationDetails.filePath, + originalContent: waitingCall.confirmationDetails.originalContent, + newContent: waitingCall.confirmationDetails.newContent, + }; + } + } + + const errorMessage = `[Operation Cancelled] Reason: ${reason}`; + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + invocation: 'invocation' in call ? call.invocation : undefined, + status: 'cancelled', + response: { + callId: call.request.callId, + responseParts: [ + { + functionResponse: { + id: call.request.callId, + name: call.request.name, + response: { + error: errorMessage, + }, + }, + }, + ], + resultDisplay, + error: undefined, + errorType: undefined, + contentLength: errorMessage.length, + }, + durationMs: startTime ? Date.now() - startTime : undefined, + outcome: call.outcome, + } as CancelledToolCall; + } + + private toValidating(call: ToolCall): ValidatingToolCall { + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + status: 'validating', + startTime: 'startTime' in call ? call.startTime : undefined, + outcome: call.outcome, + invocation: 'invocation' in call ? call.invocation : undefined, + } as ValidatingToolCall; + } + + private toExecuting(call: ToolCall, data?: unknown): ExecutingToolCall { + const execData = data as Partial | undefined; + const liveOutput = + execData?.liveOutput ?? + ('liveOutput' in call ? call.liveOutput : undefined); + const pid = execData?.pid ?? ('pid' in call ? call.pid : undefined); + + return { + request: call.request, + tool: 'tool' in call ? call.tool : undefined, + status: 'executing', + startTime: 'startTime' in call ? call.startTime : undefined, + outcome: call.outcome, + invocation: 'invocation' in call ? call.invocation : undefined, + liveOutput, + pid, + } as ExecutingToolCall; + } +} diff --git a/packages/core/src/scheduler/tool-modifier.test.ts b/packages/core/src/scheduler/tool-modifier.test.ts new file mode 100644 index 00000000000..8107e4c9011 --- /dev/null +++ b/packages/core/src/scheduler/tool-modifier.test.ts @@ -0,0 +1,252 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ToolModificationHandler } from './tool-modifier.js'; +import type { WaitingToolCall, ToolCallRequestInfo } from './types.js'; +import * as modifiableToolModule from '../tools/modifiable-tool.js'; +import * as Diff from 'diff'; +import { MockModifiableTool, MockTool } from '../test-utils/mock-tool.js'; +import type { + ToolResult, + ToolInvocation, + ToolConfirmationPayload, +} from '../tools/tools.js'; +import type { ModifyContext } from '../tools/modifiable-tool.js'; +import type { Mock } from 'vitest'; + +// Mock the modules that export functions we need to control +vi.mock('diff', () => ({ + createPatch: vi.fn(), + diffLines: vi.fn(), +})); + +vi.mock('../tools/modifiable-tool.js', () => ({ + isModifiableDeclarativeTool: vi.fn(), + modifyWithEditor: vi.fn(), +})); + +type MockModifyContext = { + [K in keyof ModifyContext>]: Mock; +}; + +function createMockWaitingToolCall( + overrides: Partial = {}, +): WaitingToolCall { + return { + status: 'awaiting_approval', + request: { + callId: 'test-call-id', + name: 'test-tool', + args: {}, + isClientInitiated: false, + prompt_id: 'test-prompt-id', + } as ToolCallRequestInfo, + tool: new MockTool({ name: 'test-tool' }), + invocation: {} as ToolInvocation, ToolResult>, // We generally don't check invocation details in these tests + confirmationDetails: { + type: 'edit', + title: 'Test Confirmation', + fileName: 'test.txt', + filePath: '/path/to/test.txt', + fileDiff: 'diff', + originalContent: 'original', + newContent: 'new', + onConfirm: async () => {}, + }, + ...overrides, + }; +} + +describe('ToolModificationHandler', () => { + let handler: ToolModificationHandler; + let mockModifiableTool: MockModifiableTool; + let mockPlainTool: MockTool; + let mockModifyContext: MockModifyContext; + + beforeEach(() => { + vi.clearAllMocks(); + handler = new ToolModificationHandler(); + mockModifiableTool = new MockModifiableTool(); + mockPlainTool = new MockTool({ name: 'plainTool' }); + + mockModifyContext = { + getCurrentContent: vi.fn(), + getFilePath: vi.fn(), + createUpdatedParams: vi.fn(), + getProposedContent: vi.fn(), + }; + + vi.spyOn(mockModifiableTool, 'getModifyContext').mockReturnValue( + mockModifyContext as unknown as ModifyContext>, + ); + }); + + describe('handleModifyWithEditor', () => { + it('should return undefined if tool is not modifiable', async () => { + vi.mocked( + modifiableToolModule.isModifiableDeclarativeTool, + ).mockReturnValue(false); + + const mockWaitingToolCall = createMockWaitingToolCall({ + tool: mockPlainTool, + request: { + callId: 'call-1', + name: 'plainTool', + args: { path: 'foo.txt' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + }); + + const result = await handler.handleModifyWithEditor( + mockWaitingToolCall, + 'vscode', + new AbortController().signal, + ); + + expect(result).toBeUndefined(); + }); + + it('should call modifyWithEditor and return updated params', async () => { + vi.mocked( + modifiableToolModule.isModifiableDeclarativeTool, + ).mockReturnValue(true); + + vi.mocked(modifiableToolModule.modifyWithEditor).mockResolvedValue({ + updatedParams: { path: 'foo.txt', content: 'new' }, + updatedDiff: 'diff', + }); + + const mockWaitingToolCall = createMockWaitingToolCall({ + tool: mockModifiableTool, + request: { + callId: 'call-1', + name: 'mockModifiableTool', + args: { path: 'foo.txt' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + confirmationDetails: { + type: 'edit', + title: 'Confirm', + fileName: 'foo.txt', + filePath: 'foo.txt', + fileDiff: 'diff', + originalContent: 'old', + newContent: 'new', + onConfirm: async () => {}, + }, + }); + + const result = await handler.handleModifyWithEditor( + mockWaitingToolCall, + 'vscode', + new AbortController().signal, + ); + + expect(modifiableToolModule.modifyWithEditor).toHaveBeenCalledWith( + mockWaitingToolCall.request.args, + mockModifyContext, + 'vscode', + expect.any(AbortSignal), + { currentContent: 'old', proposedContent: 'new' }, + ); + + expect(result).toEqual({ + updatedParams: { path: 'foo.txt', content: 'new' }, + updatedDiff: 'diff', + }); + }); + }); + + describe('applyInlineModify', () => { + it('should return undefined if tool is not modifiable', async () => { + vi.mocked( + modifiableToolModule.isModifiableDeclarativeTool, + ).mockReturnValue(false); + + const mockWaitingToolCall = createMockWaitingToolCall({ + tool: mockPlainTool, + }); + + const result = await handler.applyInlineModify( + mockWaitingToolCall, + { newContent: 'foo' }, + new AbortController().signal, + ); + + expect(result).toBeUndefined(); + }); + + it('should return undefined if payload has no new content', async () => { + vi.mocked( + modifiableToolModule.isModifiableDeclarativeTool, + ).mockReturnValue(true); + + const mockWaitingToolCall = createMockWaitingToolCall({ + tool: mockModifiableTool, + }); + + const result = await handler.applyInlineModify( + mockWaitingToolCall, + { newContent: undefined } as unknown as ToolConfirmationPayload, + new AbortController().signal, + ); + + expect(result).toBeUndefined(); + }); + + it('should calculate diff and return updated params', async () => { + vi.mocked( + modifiableToolModule.isModifiableDeclarativeTool, + ).mockReturnValue(true); + (Diff.createPatch as unknown as Mock).mockReturnValue('mock-diff'); + + mockModifyContext.getCurrentContent.mockResolvedValue('old content'); + mockModifyContext.getFilePath.mockReturnValue('test.txt'); + mockModifyContext.createUpdatedParams.mockReturnValue({ + content: 'new content', + }); + + const mockWaitingToolCall = createMockWaitingToolCall({ + tool: mockModifiableTool, + request: { + callId: 'call-1', + name: 'mockModifiableTool', + args: { content: 'original' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + }); + + const result = await handler.applyInlineModify( + mockWaitingToolCall, + { newContent: 'new content' }, + new AbortController().signal, + ); + + expect(mockModifyContext.getCurrentContent).toHaveBeenCalled(); + expect(mockModifyContext.createUpdatedParams).toHaveBeenCalledWith( + 'old content', + 'new content', + { content: 'original' }, + ); + expect(Diff.createPatch).toHaveBeenCalledWith( + 'test.txt', + 'old content', + 'new content', + 'Current', + 'Proposed', + ); + + expect(result).toEqual({ + updatedParams: { content: 'new content' }, + updatedDiff: 'mock-diff', + }); + }); + }); +}); diff --git a/packages/core/src/scheduler/tool-modifier.ts b/packages/core/src/scheduler/tool-modifier.ts new file mode 100644 index 00000000000..c7d9c93c679 --- /dev/null +++ b/packages/core/src/scheduler/tool-modifier.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as Diff from 'diff'; +import type { EditorType } from '../utils/editor.js'; +import { + isModifiableDeclarativeTool, + modifyWithEditor, + type ModifyContext, +} from '../tools/modifiable-tool.js'; +import type { ToolConfirmationPayload } from '../tools/tools.js'; +import type { WaitingToolCall } from './types.js'; + +export interface ModificationResult { + updatedParams: Record; + updatedDiff?: string; +} + +export class ToolModificationHandler { + /** + * Handles the "Modify with Editor" flow where an external editor is launched + * to modify the tool's parameters. + */ + async handleModifyWithEditor( + toolCall: WaitingToolCall, + editorType: EditorType, + signal: AbortSignal, + ): Promise { + if (!isModifiableDeclarativeTool(toolCall.tool)) { + return undefined; + } + + const confirmationDetails = toolCall.confirmationDetails; + const modifyContext = toolCall.tool.getModifyContext(signal); + + const contentOverrides = + confirmationDetails.type === 'edit' + ? { + currentContent: confirmationDetails.originalContent, + proposedContent: confirmationDetails.newContent, + } + : undefined; + + const { updatedParams, updatedDiff } = await modifyWithEditor< + typeof toolCall.request.args + >( + toolCall.request.args, + modifyContext as ModifyContext, + editorType, + signal, + contentOverrides, + ); + + return { + updatedParams, + updatedDiff, + }; + } + + /** + * Applies user-provided inline content updates (e.g. from the chat UI). + */ + async applyInlineModify( + toolCall: WaitingToolCall, + payload: ToolConfirmationPayload, + signal: AbortSignal, + ): Promise { + if ( + toolCall.confirmationDetails.type !== 'edit' || + !payload.newContent || + !isModifiableDeclarativeTool(toolCall.tool) + ) { + return undefined; + } + + const modifyContext = toolCall.tool.getModifyContext( + signal, + ) as ModifyContext; + const currentContent = await modifyContext.getCurrentContent( + toolCall.request.args, + ); + + const updatedParams = modifyContext.createUpdatedParams( + currentContent, + payload.newContent, + toolCall.request.args, + ); + + const updatedDiff = Diff.createPatch( + modifyContext.getFilePath(toolCall.request.args), + currentContent, + payload.newContent, + 'Current', + 'Proposed', + ); + + return { + updatedParams, + updatedDiff, + }; + } +} diff --git a/packages/core/src/utils/editor.test.ts b/packages/core/src/utils/editor.test.ts index 82b886f3662..78035c4cc96 100644 --- a/packages/core/src/utils/editor.test.ts +++ b/packages/core/src/utils/editor.test.ts @@ -311,11 +311,18 @@ describe('editor utils', () => { }); } - it('should return the correct command for emacs', () => { - const command = getDiffCommand('old.txt', 'new.txt', 'emacs'); + it('should return the correct command for emacs with escaped paths', () => { + const command = getDiffCommand( + 'old file "quote".txt', + 'new file \\back\\slash.txt', + 'emacs', + ); expect(command).toEqual({ command: 'emacs', - args: ['--eval', '(ediff "old.txt" "new.txt")'], + args: [ + '--eval', + '(ediff "old file \\"quote\\".txt" "new file \\\\back\\\\slash.txt")', + ], }); }); diff --git a/packages/core/src/utils/editor.ts b/packages/core/src/utils/editor.ts index b71a0b23ebc..742d1157fbf 100644 --- a/packages/core/src/utils/editor.ts +++ b/packages/core/src/utils/editor.ts @@ -60,6 +60,14 @@ function isValidEditorType(editor: string): editor is EditorType { return EDITORS_SET.has(editor); } +/** + * Escapes a string for use in an Emacs Lisp string literal. + * Wraps in double quotes and escapes backslashes and double quotes. + */ +function escapeELispString(str: string): string { + return `"${str.replace(/\\/g, '\\\\').replace(/"/g, '\\"')}"`; +} + interface DiffCommand { command: string; args: string[]; @@ -182,7 +190,10 @@ export function getDiffCommand( case 'emacs': return { command: 'emacs', - args: ['--eval', `(ediff "${oldPath}" "${newPath}")`], + args: [ + '--eval', + `(ediff ${escapeELispString(oldPath)} ${escapeELispString(newPath)})`, + ], }; case 'hx': return {