Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions packages/cli/src/ui/hooks/useGeminiStream.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const MockedGeminiClientClass = vi.hoisted(() =>
this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream;
this.addHistory = vi.fn();
this.getCurrentSequenceModel = vi.fn().mockReturnValue('gemini-2.0-flash');
this.getChat = vi.fn().mockReturnValue({
recordCompletedToolCalls: vi.fn(),
});
Expand Down Expand Up @@ -84,6 +85,9 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
UserPromptEvent: MockedUserPromptEvent,
parseAndFormatApiError: mockParseAndFormatApiError,
tokenLimit: vi.fn().mockReturnValue(100), // Mock tokenLimit
ValidationRequiredError: class {
userHandled = false;
},
};
});

Expand Down Expand Up @@ -222,6 +226,11 @@ describe('useGeminiStream', () => {
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
addHistory: vi.fn(),
getContentGenerator: vi.fn().mockReturnValue({
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
}),
getSessionId() {
return 'test-session-id';
},
Expand Down Expand Up @@ -3245,4 +3254,229 @@ describe('useGeminiStream', () => {
expect(coreEvents.emitFeedback).not.toHaveBeenCalled();
});
});

it('should not resubmit tool responses that have already been submitted to Gemini', async () => {
// This tests the fix for the unrecoverable 400 error where already-submitted
// tools were being resubmitted due to stale closure in handleCompletedTools.

const alreadySubmittedTool = {
request: {
callId: 'already-submitted-call',
name: 'tool1',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-race',
},
status: 'success',
responseSubmittedToGemini: true, // KEY: Already submitted
response: {
callId: 'already-submitted-call',
responseParts: [{ text: 'already sent' }],
errorType: undefined,
},
tool: { displayName: 'Tool1' },
invocation: {
getDescription: () => 'Mock description',
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall;

const newTool = {
request: {
callId: 'new-call',
name: 'tool2',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-race',
},
status: 'success',
responseSubmittedToGemini: false, // Not yet submitted
response: {
callId: 'new-call',
responseParts: [{ text: 'new response' }],
errorType: undefined,
},
tool: { displayName: 'Tool2' },
invocation: {
getDescription: () => 'Mock description',
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall;

let capturedOnComplete:
| ((tools: TrackedToolCall[]) => Promise<void>)
| null = null;

mockUseToolScheduler.mockImplementation(
(onComplete: (tools: TrackedToolCall[]) => Promise<void>) => {
capturedOnComplete = onComplete;
return [
[alreadySubmittedTool, newTool], // Current tracked state with flags
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(),
];
},
);

renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
80,
24,
),
);

// Simulate scheduler calling onComplete
const fromScheduler = [
{ ...alreadySubmittedTool, responseSubmittedToGemini: undefined },
{ ...newTool, responseSubmittedToGemini: undefined },
];

await act(async () => {
if (capturedOnComplete) {
await capturedOnComplete(fromScheduler as TrackedToolCall[]);
}
});

// Only the NEW tool should be marked as submitted
await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['new-call']);
// Should NOT be called with already-submitted-call
expect(mockMarkToolsAsSubmitted).not.toHaveBeenCalledWith([
'already-submitted-call',
]);
});
});

it('should await submitQuery before marking tools as submitted (race condition fix)', async () => {
// This tests the fix for the race condition where markToolsAsSubmitted was
// called BEFORE submitQuery completed, allowing user prompts to race ahead
// of tool responses.

const toolCallResponseParts: Part[] = [{ text: 'tool response' }];

const completedTool = {
request: {
callId: 'race-test-call',
name: 'test_tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-race-test',
},
status: 'success',
responseSubmittedToGemini: false,
response: {
callId: 'race-test-call',
responseParts: toolCallResponseParts,
errorType: undefined,
},
tool: { displayName: 'TestTool' },
invocation: {
getDescription: () => 'Mock description',
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall;

// Track the order of operations
const callOrder: string[] = [];
let resolveStreamPromise: () => void;
const streamPromise = new Promise<void>((resolve) => {
resolveStreamPromise = resolve;
});

// Mock sendMessageStream to be a slow async operation
mockSendMessageStream.mockImplementation(() => {
callOrder.push('sendMessageStream:start');
return (async function* () {
await streamPromise; // Wait until we explicitly resolve
callOrder.push('sendMessageStream:end');
yield {
type: ServerGeminiEventType.Finished,
value: { finishReason: 'STOP' },
};
})();
});

// Track when markToolsAsSubmitted is called
mockMarkToolsAsSubmitted.mockImplementation((callIds: string[]) => {
callOrder.push(`markToolsAsSubmitted:${callIds.join(',')}`);
});

let capturedOnComplete:
| ((tools: TrackedToolCall[]) => Promise<void>)
| null = null;

mockUseToolScheduler.mockImplementation(
(onComplete: (tools: TrackedToolCall[]) => Promise<void>) => {
capturedOnComplete = onComplete;
return [
[completedTool],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(),
];
},
);

renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
80,
24,
),
);

// Start tool completion
const completionPromise = act(async () => {
if (capturedOnComplete) {
await capturedOnComplete([completedTool] as TrackedToolCall[]);
}
});

// Give time for sendMessageStream to start
await new Promise((r) => setTimeout(r, 10));

expect(callOrder).toContain('sendMessageStream:start');
expect(callOrder).not.toContain('markToolsAsSubmitted:race-test-call');

// Now resolve the stream
resolveStreamPromise!();
await completionPromise;

// After stream completes, markToolsAsSubmitted should be called
await waitFor(() => {
expect(callOrder).toContain('markToolsAsSubmitted:race-test-call');
});

// Verify the correct order
const streamEndIndex = callOrder.indexOf('sendMessageStream:end');
const markIndex = callOrder.indexOf('markToolsAsSubmitted:race-test-call');
expect(streamEndIndex).toBeLessThan(markIndex);
});
});
40 changes: 27 additions & 13 deletions packages/cli/src/ui/hooks/useGeminiStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ export const useGeminiStream = (
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
return;
return false;
}

const queryId = `${Date.now()}-${Math.random()}`;
Expand All @@ -987,7 +987,7 @@ export const useGeminiStream = (
streamingState === StreamingState.WaitingForConfirmation) &&
!options?.isContinuation
)
return;
return false;

const userMessageTimestamp = Date.now();

Expand All @@ -1013,7 +1013,7 @@ export const useGeminiStream = (
);

if (!shouldProceed || queryToSend === null) {
return;
return false;
}

if (!options?.isContinuation) {
Expand Down Expand Up @@ -1054,7 +1054,7 @@ export const useGeminiStream = (
);

if (processingStatus === StreamProcessingStatus.UserCancelled) {
return;
return false;
}

if (pendingHistoryItemRef.current) {
Expand All @@ -1065,7 +1065,7 @@ export const useGeminiStream = (
loopDetectedRef.current = false;
// Show the confirmation dialog to choose whether to disable loop detection
setLoopDetectionConfirmationRequest({
onComplete: (result: {
onComplete: async (result: {
userSelection: 'disable' | 'keep';
}) => {
setLoopDetectionConfirmationRequest(null);
Expand All @@ -1081,8 +1081,7 @@ export const useGeminiStream = (
});

if (lastQueryRef.current && lastPromptIdRef.current) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
submitQuery(
await submitQuery(
lastQueryRef.current,
{ isContinuation: true },
lastPromptIdRef.current,
Expand All @@ -1097,6 +1096,7 @@ export const useGeminiStream = (
},
});
}
return true;
} catch (error: unknown) {
spanMetadata.error = error;
if (error instanceof UnauthorizedError) {
Expand All @@ -1122,6 +1122,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
}
return false;
} finally {
if (activeQueryIdRef.current === queryId) {
setIsResponding(false);
Expand Down Expand Up @@ -1195,6 +1196,17 @@ export const useGeminiStream = (
(
tc: TrackedToolCall,
): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
// Check if we've already submitted this tool call.
// We need to look up the tracked version because the incoming 'tc'
// comes directly from the scheduler core and lacks the
// 'responseSubmittedToGemini' flag.
const trackedToolCall = toolCalls.find(
(t) => t.request.callId === tc.request.callId,
);
if (trackedToolCall?.responseSubmittedToGemini) {
return false;
}

const isTerminalState =
tc.status === 'success' ||
tc.status === 'error' ||
Expand Down Expand Up @@ -1286,8 +1298,7 @@ export const useGeminiStream = (
const combinedParts = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
geminiClient.addHistory({
await geminiClient.addHistory({
role: 'user',
parts: combinedParts,
});
Expand All @@ -1311,21 +1322,23 @@ export const useGeminiStream = (
(toolCall) => toolCall.request.prompt_id,
);

markToolsAsSubmitted(callIdsToMarkAsSubmitted);

// Don't continue if model was switched due to quota error
if (modelSwitchedFromQuotaError) {
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
return;
}

// eslint-disable-next-line @typescript-eslint/no-floating-promises
submitQuery(
const submitted = await submitQuery(
responsesToSend,
{
isContinuation: true,
},
prompt_ids[0],
);

if (submitted) {
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
}
},
[
submitQuery,
Expand All @@ -1334,6 +1347,7 @@ export const useGeminiStream = (
performMemoryRefresh,
modelSwitchedFromQuotaError,
addItem,
toolCalls,
],
);

Expand Down