Skip to content

Commit 903b1e9

Browse files
adamfweidmanscidomino
authored andcommitted
fix: update currentSequenceModel when modelChanged (#17051)
1 parent bd15549 commit 903b1e9

File tree

7 files changed

+77
-36
lines changed

7 files changed

+77
-36
lines changed

packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,6 @@ describe('useQuotaAndFallback', () => {
166166
const intent = await promise!;
167167
expect(intent).toBe('retry_always');
168168

169-
// Verify activateFallbackMode was called
170-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
171-
'gemini-flash',
172-
);
173-
174169
// The pending request should be cleared from the state
175170
expect(result.current.proQuotaRequest).toBeNull();
176171
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1);
@@ -282,11 +277,6 @@ describe('useQuotaAndFallback', () => {
282277
const intent = await promise!;
283278
expect(intent).toBe('retry_always');
284279

285-
// Verify activateFallbackMode was called
286-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
287-
'model-B',
288-
);
289-
290280
// The pending request should be cleared from the state
291281
expect(result.current.proQuotaRequest).toBeNull();
292282
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
@@ -342,11 +332,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
342332
const intent = await promise!;
343333
expect(intent).toBe('retry_always');
344334

345-
// Verify activateFallbackMode was called
346-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
347-
'gemini-2.5-pro',
348-
);
349-
350335
expect(result.current.proQuotaRequest).toBeNull();
351336
});
352337
});
@@ -430,11 +415,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
430415
expect(intent).toBe('retry_always');
431416
expect(result.current.proQuotaRequest).toBeNull();
432417

433-
// Verify activateFallbackMode was called
434-
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
435-
'gemini-flash',
436-
);
437-
438418
// Verify quota error flags are reset
439419
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(false);
440420
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(false);

packages/cli/src/ui/hooks/useQuotaAndFallback.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,6 @@ export function useQuotaAndFallback({
135135
config.setQuotaErrorOccurred(false);
136136

137137
if (choice === 'retry_always') {
138-
// Set the model to the fallback model for the current session.
139-
// This ensures the Footer updates and future turns use this model.
140-
// The change is not persisted, so the original model is restored on restart.
141-
config.activateFallbackMode(proQuotaRequest.fallbackModel);
142138
historyManager.addItem(
143139
{
144140
type: MessageType.INFO,

packages/core/src/config/config.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,9 +1955,8 @@ export class Config {
19551955
*/
19561956
async dispose(): Promise<void> {
19571957
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
1958-
if (this.agentRegistry) {
1959-
this.agentRegistry.dispose();
1960-
}
1958+
this.agentRegistry?.dispose();
1959+
this.geminiClient?.dispose();
19611960
if (this.mcpClientManager) {
19621961
await this.mcpClientManager.stop();
19631962
}

packages/core/src/core/client.test.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import type {
4848
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
4949
import * as policyCatalog from '../availability/policyCatalog.js';
5050
import { partToString } from '../utils/partUtils.js';
51+
import { coreEvents } from '../utils/events.js';
5152

5253
vi.mock('../services/chatCompressionService.js');
5354

@@ -290,6 +291,7 @@ describe('Gemini Client (client.ts)', () => {
290291
});
291292

292293
afterEach(() => {
294+
client.dispose();
293295
vi.restoreAllMocks();
294296
});
295297

@@ -1579,6 +1581,55 @@ ${JSON.stringify(
15791581
expect.any(AbortSignal),
15801582
);
15811583
});
1584+
1585+
it('should re-route within the same prompt when the configured model changes', async () => {
1586+
mockTurnRunFn.mockClear();
1587+
mockTurnRunFn.mockImplementation(async function* () {
1588+
yield { type: 'content', value: 'Hello' };
1589+
});
1590+
1591+
mockRouterService.route.mockResolvedValueOnce({
1592+
model: 'original-model',
1593+
reason: 'test',
1594+
});
1595+
1596+
let stream = client.sendMessageStream(
1597+
[{ text: 'Hi' }],
1598+
new AbortController().signal,
1599+
'prompt-1',
1600+
);
1601+
await fromAsync(stream);
1602+
1603+
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
1604+
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1605+
1,
1606+
{ model: 'original-model' },
1607+
[{ text: 'Hi' }],
1608+
expect.any(AbortSignal),
1609+
);
1610+
1611+
mockRouterService.route.mockResolvedValue({
1612+
model: 'fallback-model',
1613+
reason: 'test',
1614+
});
1615+
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-flash');
1616+
coreEvents.emitModelChanged('gemini-2.5-flash');
1617+
1618+
stream = client.sendMessageStream(
1619+
[{ text: 'Continue' }],
1620+
new AbortController().signal,
1621+
'prompt-1',
1622+
);
1623+
await fromAsync(stream);
1624+
1625+
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
1626+
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1627+
2,
1628+
{ model: 'fallback-model' },
1629+
[{ text: 'Continue' }],
1630+
expect.any(AbortSignal),
1631+
);
1632+
});
15821633
});
15831634

15841635
it('should use getGlobalMemory for system instruction when JIT is enabled', async () => {

packages/core/src/core/client.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import {
5858
import { resolveModel } from '../config/models.js';
5959
import type { RetryAvailabilityContext } from '../utils/retry.js';
6060
import { partToString } from '../utils/partUtils.js';
61+
import { coreEvents, CoreEvent } from '../utils/events.js';
6162

6263
const MAX_TURNS = 100;
6364

@@ -94,8 +95,14 @@ export class GeminiClient {
9495
this.loopDetector = new LoopDetectionService(config);
9596
this.compressionService = new ChatCompressionService();
9697
this.lastPromptId = this.config.getSessionId();
98+
99+
coreEvents.on(CoreEvent.ModelChanged, this.handleModelChanged);
97100
}
98101

102+
private handleModelChanged = () => {
103+
this.currentSequenceModel = null;
104+
};
105+
99106
// Hook state to deduplicate BeforeAgent calls and track response for
100107
// AfterAgent
101108
private hookStateMap = new Map<
@@ -253,6 +260,10 @@ export class GeminiClient {
253260
this.updateTelemetryTokenCount();
254261
}
255262

263+
dispose() {
264+
coreEvents.off(CoreEvent.ModelChanged, this.handleModelChanged);
265+
}
266+
256267
async resumeChat(
257268
history: Content[],
258269
resumedSessionData?: ResumedSessionData,

packages/core/src/fallback/handler.test.ts

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
6565
fallbackHandler: undefined,
6666
getFallbackModelHandler: vi.fn(),
6767
setActiveModel: vi.fn(),
68+
setModel: vi.fn(),
69+
activateFallbackMode: vi.fn(),
6870
getModelAvailabilityService: vi.fn(() =>
6971
createAvailabilityServiceMock({
7072
selectedModel: FALLBACK_MODEL,
@@ -198,7 +200,7 @@ describe('handleFallback', () => {
198200

199201
expect(result).toBe(true);
200202
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
201-
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
203+
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
202204
DEFAULT_GEMINI_FLASH_MODEL,
203205
);
204206
} finally {
@@ -273,7 +275,7 @@ describe('handleFallback', () => {
273275
expect(openBrowserSecurely).toHaveBeenCalledWith(
274276
'https://goo.gle/set-up-gemini-code-assist',
275277
);
276-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
278+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
277279
});
278280

279281
it('should catch errors from the handler, log an error, and return null', async () => {
@@ -378,7 +380,7 @@ describe('handleFallback', () => {
378380
);
379381
});
380382

381-
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
383+
it('calls activateFallbackMode when handler returns "retry_always"', async () => {
382384
policyHandler.mockResolvedValue('retry_always');
383385
vi.mocked(policyConfig.getModel).mockReturnValue(
384386
DEFAULT_GEMINI_MODEL_AUTO,
@@ -391,11 +393,13 @@ describe('handleFallback', () => {
391393
);
392394

393395
expect(result).toBe(true);
394-
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
396+
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
397+
FALLBACK_MODEL,
398+
);
395399
// TODO: add logging expect statement
396400
});
397401

398-
it('does NOT call setActiveModel when handler returns "stop"', async () => {
402+
it('does NOT call activateFallbackMode when handler returns "stop"', async () => {
399403
policyHandler.mockResolvedValue('stop');
400404

401405
const result = await handleFallback(
@@ -405,11 +409,11 @@ describe('handleFallback', () => {
405409
);
406410

407411
expect(result).toBe(false);
408-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
412+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
409413
// TODO: add logging expect statement
410414
});
411415

412-
it('does NOT call setActiveModel when handler returns "retry_once"', async () => {
416+
it('does NOT call activateFallbackMode when handler returns "retry_once"', async () => {
413417
policyHandler.mockResolvedValue('retry_once');
414418

415419
const result = await handleFallback(
@@ -419,7 +423,7 @@ describe('handleFallback', () => {
419423
);
420424

421425
expect(result).toBe(true);
422-
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
426+
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
423427
});
424428
});
425429
});

packages/core/src/fallback/handler.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async function processIntent(
131131
case 'retry_always':
132132
// TODO(telemetry): Implement generic fallback event logging. Existing
133133
// logFlashFallback is specific to a single Model.
134-
config.setActiveModel(fallbackModel);
134+
config.activateFallbackMode(fallbackModel);
135135
return true;
136136

137137
case 'retry_once':

0 commit comments

Comments
 (0)