Skip to content

Commit c1df7fb

Browse files
adamfweidmanThomas-Shephard
authored andcommitted
fix: update currentSequenceModel when modelChanged (google-gemini#17051)
1 parent c349544 commit c1df7fb

File tree

7 files changed

+77
-36
lines changed

7 files changed

+77
-36
lines changed

packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,6 @@ describe('useQuotaAndFallback', () => {
166166
const intent = await promise!;
167167
expect(intent).toBe('retry_always');
168168

169-
// Verify activateFallbackMode was called
170-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
171-
'gemini-flash',
172-
);
173-
174169
// The pending request should be cleared from the state
175170
expect(result.current.proQuotaRequest).toBeNull();
176171
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1);
@@ -282,11 +277,6 @@ describe('useQuotaAndFallback', () => {
282277
const intent = await promise!;
283278
expect(intent).toBe('retry_always');
284279

285-
// Verify activateFallbackMode was called
286-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
287-
'model-B',
288-
);
289-
290280
// The pending request should be cleared from the state
291281
expect(result.current.proQuotaRequest).toBeNull();
292282
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
@@ -342,11 +332,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
342332
const intent = await promise!;
343333
expect(intent).toBe('retry_always');
344334

345-
// Verify activateFallbackMode was called
346-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
347-
'gemini-2.5-pro',
348-
);
349-
350335
expect(result.current.proQuotaRequest).toBeNull();
351336
});
352337
});
@@ -430,11 +415,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
430415
expect(intent).toBe('retry_always');
431416
expect(result.current.proQuotaRequest).toBeNull();
432417

433-
// Verify activateFallbackMode was called
434-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
435-
'gemini-flash',
436-
);
437-
438418
// Verify quota error flags are reset
439419
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(false);
440420
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(false);

packages/cli/src/ui/hooks/useQuotaAndFallback.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,6 @@ export function useQuotaAndFallback({
135135
config.setQuotaErrorOccurred(false);
136136

137137
if (choice === 'retry_always') {
138-
// Set the model to the fallback model for the current session.
139-
// This ensures the Footer updates and future turns use this model.
140-
// The change is not persisted, so the original model is restored on restart.
141-
config.activateFallbackMode(proQuotaRequest.fallbackModel);
142138
historyManager.addItem(
143139
{
144140
type: MessageType.INFO,

packages/core/src/config/config.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2037,9 +2037,8 @@ export class Config {
20372037
*/
20382038
async dispose(): Promise<void> {
20392039
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
2040-
if (this.agentRegistry) {
2041-
this.agentRegistry.dispose();
2042-
}
2040+
this.agentRegistry?.dispose();
2041+
this.geminiClient?.dispose();
20432042
if (this.mcpClientManager) {
20442043
await this.mcpClientManager.stop();
20452044
}

packages/core/src/core/client.test.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import type {
4848
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
4949
import * as policyCatalog from '../availability/policyCatalog.js';
5050
import { partToString } from '../utils/partUtils.js';
51+
import { coreEvents } from '../utils/events.js';
5152

5253
// Mock fs module to prevent actual file system operations during tests
5354
const mockFileSystem = new Map<string, string>();
@@ -281,6 +282,7 @@ describe('Gemini Client (client.ts)', () => {
281282
});
282283

283284
afterEach(() => {
285+
client.dispose();
284286
vi.restoreAllMocks();
285287
});
286288

@@ -1757,6 +1759,55 @@ ${JSON.stringify(
17571759
expect.any(AbortSignal),
17581760
);
17591761
});
1762+
1763+
it('should re-route within the same prompt when the configured model changes', async () => {
1764+
mockTurnRunFn.mockClear();
1765+
mockTurnRunFn.mockImplementation(async function* () {
1766+
yield { type: 'content', value: 'Hello' };
1767+
});
1768+
1769+
mockRouterService.route.mockResolvedValueOnce({
1770+
model: 'original-model',
1771+
reason: 'test',
1772+
});
1773+
1774+
let stream = client.sendMessageStream(
1775+
[{ text: 'Hi' }],
1776+
new AbortController().signal,
1777+
'prompt-1',
1778+
);
1779+
await fromAsync(stream);
1780+
1781+
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
1782+
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1783+
1,
1784+
{ model: 'original-model' },
1785+
[{ text: 'Hi' }],
1786+
expect.any(AbortSignal),
1787+
);
1788+
1789+
mockRouterService.route.mockResolvedValue({
1790+
model: 'fallback-model',
1791+
reason: 'test',
1792+
});
1793+
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-flash');
1794+
coreEvents.emitModelChanged('gemini-2.5-flash');
1795+
1796+
stream = client.sendMessageStream(
1797+
[{ text: 'Continue' }],
1798+
new AbortController().signal,
1799+
'prompt-1',
1800+
);
1801+
await fromAsync(stream);
1802+
1803+
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
1804+
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1805+
2,
1806+
{ model: 'fallback-model' },
1807+
[{ text: 'Continue' }],
1808+
expect.any(AbortSignal),
1809+
);
1810+
});
17601811
});
17611812

17621813
it('should use getGlobalMemory for system instruction when JIT is enabled', async () => {

packages/core/src/core/client.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import {
5858
import { resolveModel } from '../config/models.js';
5959
import type { RetryAvailabilityContext } from '../utils/retry.js';
6060
import { partToString } from '../utils/partUtils.js';
61+
import { coreEvents, CoreEvent } from '../utils/events.js';
6162

6263
const MAX_TURNS = 100;
6364

@@ -94,8 +95,14 @@ export class GeminiClient {
9495
this.loopDetector = new LoopDetectionService(config);
9596
this.compressionService = new ChatCompressionService();
9697
this.lastPromptId = this.config.getSessionId();
98+
99+
coreEvents.on(CoreEvent.ModelChanged, this.handleModelChanged);
97100
}
98101

102+
private handleModelChanged = () => {
103+
this.currentSequenceModel = null;
104+
};
105+
99106
// Hook state to deduplicate BeforeAgent calls and track response for
100107
// AfterAgent
101108
private hookStateMap = new Map<
@@ -253,6 +260,10 @@ export class GeminiClient {
253260
this.updateTelemetryTokenCount();
254261
}
255262

263+
dispose() {
264+
coreEvents.off(CoreEvent.ModelChanged, this.handleModelChanged);
265+
}
266+
256267
async resumeChat(
257268
history: Content[],
258269
resumedSessionData?: ResumedSessionData,

packages/core/src/fallback/handler.test.ts

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
6565
fallbackHandler: undefined,
6666
getFallbackModelHandler: vi.fn(),
6767
setActiveModel: vi.fn(),
68+
setModel: vi.fn(),
69+
activateFallbackMode: vi.fn(),
6870
getModelAvailabilityService: vi.fn(() =>
6971
createAvailabilityServiceMock({
7072
selectedModel: FALLBACK_MODEL,
@@ -198,7 +200,7 @@ describe('handleFallback', () => {
198200

199201
expect(result).toBe(true);
200202
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
201-
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
203+
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
202204
DEFAULT_GEMINI_FLASH_MODEL,
203205
);
204206
} finally {
@@ -273,7 +275,7 @@ describe('handleFallback', () => {
273275
expect(openBrowserSecurely).toHaveBeenCalledWith(
274276
'https://goo.gle/set-up-gemini-code-assist',
275277
);
276-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
278+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
277279
});
278280

279281
it('should catch errors from the handler, log an error, and return null', async () => {
@@ -378,7 +380,7 @@ describe('handleFallback', () => {
378380
);
379381
});
380382

381-
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
383+
it('calls activateFallbackMode when handler returns "retry_always"', async () => {
382384
policyHandler.mockResolvedValue('retry_always');
383385
vi.mocked(policyConfig.getModel).mockReturnValue(
384386
DEFAULT_GEMINI_MODEL_AUTO,
@@ -391,11 +393,13 @@ describe('handleFallback', () => {
391393
);
392394

393395
expect(result).toBe(true);
394-
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
396+
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
397+
FALLBACK_MODEL,
398+
);
395399
// TODO: add logging expect statement
396400
});
397401

398-
it('does NOT call setActiveModel when handler returns "stop"', async () => {
402+
it('does NOT call activateFallbackMode when handler returns "stop"', async () => {
399403
policyHandler.mockResolvedValue('stop');
400404

401405
const result = await handleFallback(
@@ -405,11 +409,11 @@ describe('handleFallback', () => {
405409
);
406410

407411
expect(result).toBe(false);
408-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
412+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
409413
// TODO: add logging expect statement
410414
});
411415

412-
it('does NOT call setActiveModel when handler returns "retry_once"', async () => {
416+
it('does NOT call activateFallbackMode when handler returns "retry_once"', async () => {
413417
policyHandler.mockResolvedValue('retry_once');
414418

415419
const result = await handleFallback(
@@ -419,7 +423,7 @@ describe('handleFallback', () => {
419423
);
420424

421425
expect(result).toBe(true);
422-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
426+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
423427
});
424428
});
425429
});

packages/core/src/fallback/handler.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async function processIntent(
131131
case 'retry_always':
132132
// TODO(telemetry): Implement generic fallback event logging. Existing
133133
// logFlashFallback is specific to a single Model.
134-
config.setActiveModel(fallbackModel);
134+
config.activateFallbackMode(fallbackModel);
135135
return true;
136136

137137
case 'retry_once':

0 commit comments

Comments
 (0)