Skip to content

Commit e237de2

Browse files
authored
feat: model selection (#243)
Adds support for selecting a default model, and switching between models. For now we include support for Anthropic models, will follow up with a PR for OpenAI models (the gateway needs some formatting fixes for it) It also sets the default model to Opus - looks like we were using Sonnet before. <img width="697" height="338" alt="Screenshot 2025-12-11 at 11 18 33" src="https://github.com/user-attachments/assets/53830f0d-ef0f-47cd-816a-f04ed228c0c5" />
1 parent 0e28769 commit e237de2

File tree

11 files changed

+307
-23
lines changed

11 files changed

+307
-23
lines changed

apps/array/src/main/preload.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ contextBridge.exposeInMainWorld("electronAPI", {
193193
newToken: string,
194194
): Promise<void> =>
195195
ipcRenderer.invoke("agent-token-refresh", taskRunId, newToken),
196+
agentSetModel: async (sessionId: string, modelId: string): Promise<void> =>
197+
ipcRenderer.invoke("agent-set-model", sessionId, modelId),
196198
onAgentEvent: (
197199
channel: string,
198200
listener: (payload: unknown) => void,

apps/array/src/main/services/session-manager.ts

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ export interface SessionConfig {
130130
credentials: PostHogCredentials;
131131
logUrl?: string; // For reconnection from S3
132132
sdkSessionId?: string; // SDK session ID for resuming Claude Code context
133+
model?: string;
133134
}
134135

135136
export interface ManagedSession {
@@ -219,8 +220,15 @@ export class SessionManager {
219220
config: SessionConfig,
220221
isReconnect: boolean,
221222
): Promise<ManagedSession | null> {
222-
const { taskId, taskRunId, repoPath, credentials, logUrl, sdkSessionId } =
223-
config;
223+
const {
224+
taskId,
225+
taskRunId,
226+
repoPath,
227+
credentials,
228+
logUrl,
229+
sdkSessionId,
230+
model,
231+
} = config;
224232

225233
const existing = this.sessions.get(taskRunId);
226234
if (existing) {
@@ -276,7 +284,7 @@ export class SessionManager {
276284
await connection.newSession({
277285
cwd: repoPath,
278286
mcpServers,
279-
_meta: { sessionId: taskRunId },
287+
_meta: { sessionId: taskRunId, model },
280288
});
281289
}
282290

@@ -355,6 +363,24 @@ export class SessionManager {
355363
return this.sessions.get(taskRunId);
356364
}
357365

366+
async setSessionModel(taskRunId: string, modelId: string): Promise<void> {
367+
const session = this.sessions.get(taskRunId);
368+
if (!session) {
369+
throw new Error(`Session not found: ${taskRunId}`);
370+
}
371+
372+
try {
373+
await session.connection.extMethod("session/setModel", {
374+
sessionId: taskRunId,
375+
modelId,
376+
});
377+
log.info("Session model updated", { taskRunId, modelId });
378+
} catch (err) {
379+
log.error("Failed to set session model", { taskRunId, modelId, err });
380+
throw err;
381+
}
382+
}
383+
358384
listSessions(taskId?: string): ManagedSession[] {
359385
const all = Array.from(this.sessions.values());
360386
return taskId ? all.filter((s) => s.taskId === taskId) : all;
@@ -507,6 +533,7 @@ interface AgentSessionParams {
507533
projectId: number;
508534
logUrl?: string;
509535
sdkSessionId?: string;
536+
model?: string;
510537
}
511538

512539
type SessionResponse = { sessionId: string; channel: string };
@@ -536,6 +563,7 @@ function toSessionConfig(params: AgentSessionParams): SessionConfig {
536563
},
537564
logUrl: params.logUrl,
538565
sdkSessionId: params.sdkSessionId,
566+
model: params.model,
539567
};
540568
}
541569

@@ -648,4 +676,15 @@ export function registerAgentIpc(
648676
sessionManager.updateSessionToken(taskRunId, newToken);
649677
},
650678
);
679+
680+
ipcMain.handle(
681+
"agent-set-model",
682+
async (
683+
_event: IpcMainInvokeEvent,
684+
sessionId: string,
685+
modelId: string,
686+
): Promise<void> => {
687+
await sessionManager.setSessionModel(sessionId, modelId);
688+
},
689+
);
651690
}

apps/array/src/renderer/features/sessions/components/MessageEditor.tsx

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import {
2929
} from "react";
3030
import { flushSync } from "react-dom";
3131
import { useMessageDraftStore } from "../stores/messageDraftStore";
32+
import { ModelSelector } from "./ModelSelector";
3233

3334
const log = logger.scope("message-editor");
3435

@@ -174,6 +175,7 @@ export interface MessageEditorHandle {
174175

175176
interface MessageEditorProps {
176177
sessionId: string;
178+
taskId?: string;
177179
placeholder?: string;
178180
repoPath?: string | null;
179181
disabled?: boolean;
@@ -191,6 +193,7 @@ export const MessageEditor = forwardRef<
191193
(
192194
{
193195
sessionId,
196+
taskId,
194197
placeholder = "Type a message... @ to mention files",
195198
repoPath,
196199
disabled = false,
@@ -449,26 +452,29 @@ export const MessageEditor = forwardRef<
449452
<EditorContent editor={editor} />
450453
</Box>
451454
<Flex justify="between" align="center">
452-
<input
453-
ref={fileInputRef}
454-
type="file"
455-
multiple
456-
onChange={handleFileSelect}
457-
style={{ display: "none" }}
458-
/>
459-
<Tooltip content="Attach file">
460-
<IconButton
461-
size="1"
462-
variant="ghost"
463-
color="gray"
464-
onClick={() => fileInputRef.current?.click()}
465-
disabled={disabled}
466-
title="Attach file"
467-
style={{ marginLeft: "0px" }}
468-
>
469-
<Paperclip size={14} weight="bold" />
470-
</IconButton>
471-
</Tooltip>
455+
<Flex gap="2" align="center">
456+
<input
457+
ref={fileInputRef}
458+
type="file"
459+
multiple
460+
onChange={handleFileSelect}
461+
style={{ display: "none" }}
462+
/>
463+
<Tooltip content="Attach file">
464+
<IconButton
465+
size="1"
466+
variant="ghost"
467+
color="gray"
468+
onClick={() => fileInputRef.current?.click()}
469+
disabled={disabled}
470+
title="Attach file"
471+
style={{ marginLeft: "0px" }}
472+
>
473+
<Paperclip size={14} weight="bold" />
474+
</IconButton>
475+
</Tooltip>
476+
<ModelSelector taskId={taskId} disabled={disabled} />
477+
</Flex>
472478
<Flex gap="4" align="center">
473479
{isLoading && onCancel ? (
474480
<Tooltip content="Stop">
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import { useSettingsStore } from "@features/settings/stores/settingsStore";
2+
import { Select, Text } from "@radix-ui/themes";
3+
import {
4+
AVAILABLE_MODELS,
5+
getModelsByProvider,
6+
type ModelProvider,
7+
} from "@shared/types/models";
8+
import { Fragment } from "react";
9+
import { useSessionStore } from "../stores/sessionStore";
10+
11+
interface ModelSelectorProps {
12+
taskId?: string;
13+
disabled?: boolean;
14+
onModelChange?: (modelId: string) => void;
15+
}
16+
17+
export function ModelSelector({
18+
taskId,
19+
disabled,
20+
onModelChange,
21+
}: ModelSelectorProps) {
22+
const defaultModel = useSettingsStore((state) => state.defaultModel);
23+
const setDefaultModel = useSettingsStore((state) => state.setDefaultModel);
24+
const setSessionModel = useSessionStore((state) => state.setSessionModel);
25+
const session = useSessionStore((state) =>
26+
taskId ? state.getSessionForTask(taskId) : undefined,
27+
);
28+
29+
// Use session model if available, otherwise fall back to default
30+
const activeModel = session?.model ?? defaultModel;
31+
32+
const handleChange = (value: string) => {
33+
// Always update the default
34+
setDefaultModel(value);
35+
onModelChange?.(value);
36+
37+
// If there's an active session, update the model mid-session
38+
if (taskId && session?.status === "connected" && !session.isCloud) {
39+
setSessionModel(taskId, value);
40+
}
41+
};
42+
43+
const modelsByProvider = getModelsByProvider();
44+
const providers = (Object.keys(modelsByProvider) as ModelProvider[]).filter(
45+
(provider) => modelsByProvider[provider].models.length > 0,
46+
);
47+
48+
const currentModel = AVAILABLE_MODELS.find((m) => m.id === activeModel);
49+
const displayName = currentModel?.name ?? activeModel;
50+
51+
return (
52+
<Select.Root
53+
value={activeModel}
54+
onValueChange={handleChange}
55+
disabled={disabled}
56+
size="1"
57+
>
58+
<Select.Trigger
59+
variant="ghost"
60+
style={{
61+
fontSize: "var(--font-size-1)",
62+
color: "var(--gray-11)",
63+
padding: "4px 8px",
64+
marginLeft: "4px",
65+
height: "auto",
66+
minHeight: "unset",
67+
}}
68+
>
69+
<Text size="1" style={{ fontFamily: "var(--font-mono)" }}>
70+
{displayName}
71+
</Text>
72+
</Select.Trigger>
73+
<Select.Content position="popper" sideOffset={4}>
74+
{providers.map((provider, index) => (
75+
<Fragment key={provider}>
76+
{index > 0 && <Select.Separator />}
77+
<Select.Group>
78+
<Select.Label>{modelsByProvider[provider].name}</Select.Label>
79+
{modelsByProvider[provider].models.map((model) => (
80+
<Select.Item key={model.id} value={model.id}>
81+
{model.name}
82+
</Select.Item>
83+
))}
84+
</Select.Group>
85+
</Fragment>
86+
))}
87+
</Select.Content>
88+
</Select.Root>
89+
);
90+
}

apps/array/src/renderer/features/sessions/components/SessionView.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function RawLogEntry({
6464
interface SessionViewProps {
6565
events: SessionEvent[];
6666
sessionId: string | null;
67+
taskId?: string;
6768
isRunning: boolean;
6869
isPromptPending?: boolean;
6970
onSendPrompt: (text: string) => void;
@@ -342,6 +343,7 @@ function groupMessagesIntoTurns(
342343
export function SessionView({
343344
events,
344345
sessionId,
346+
taskId,
345347
isRunning,
346348
isPromptPending,
347349
onSendPrompt,
@@ -615,6 +617,7 @@ export function SessionView({
615617
<Box className="border-gray-6 border-t p-3">
616618
<MessageEditor
617619
sessionId={sessionId ?? "default"}
620+
taskId={taskId}
618621
placeholder="Type a message... @ to mention files"
619622
repoPath={repoPath}
620623
disabled={!isRunning}

apps/array/src/renderer/features/sessions/stores/sessionStore.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type {
33
SessionNotification,
44
} from "@agentclientprotocol/sdk";
55
import { useAuthStore } from "@features/auth/stores/authStore";
6+
import { useSettingsStore } from "@features/settings/stores/settingsStore";
67
import { logger } from "@renderer/lib/logger";
78
import type { Task } from "@shared/types";
89
import { create } from "zustand";
@@ -95,6 +96,7 @@ export interface AgentSession {
9596
isCloud: boolean;
9697
logUrl?: string;
9798
processedLineCount?: number;
99+
model?: string;
98100
}
99101

100102
interface ConnectParams {
@@ -126,6 +128,9 @@ interface SessionStore {
126128
// Cancel ongoing prompt without terminating session
127129
cancelPrompt: (taskId: string) => Promise<boolean>;
128130

131+
// Change model for active session
132+
setSessionModel: (taskId: string, modelId: string) => Promise<void>;
133+
129134
// Internal: subscribe to IPC events
130135
_subscribeToChannel: (
131136
taskRunId: string,
@@ -317,13 +322,15 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
317322
return;
318323
}
319324

325+
const defaultModel = useSettingsStore.getState().defaultModel;
320326
const result = await window.electronAPI.agentStart({
321327
taskId,
322328
taskRunId: taskRun.id,
323329
repoPath,
324330
apiKey,
325331
apiHost,
326332
projectId,
333+
model: defaultModel,
327334
});
328335

329336
set((state) => ({
@@ -338,6 +345,7 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
338345
status: "connected",
339346
isPromptPending: false,
340347
isCloud: false,
348+
model: defaultModel,
341349
},
342350
},
343351
}));
@@ -490,6 +498,35 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
490498
}
491499
},
492500

501+
setSessionModel: async (taskId, modelId) => {
502+
const session = get().getSessionForTask(taskId);
503+
if (!session) {
504+
log.warn("No session found for model change", { taskId });
505+
return;
506+
}
507+
508+
if (session.isCloud) {
509+
log.warn("Model change not supported for cloud sessions");
510+
return;
511+
}
512+
513+
try {
514+
await window.electronAPI.agentSetModel(session.taskRunId, modelId);
515+
set((state) => ({
516+
sessions: {
517+
...state.sessions,
518+
[session.taskRunId]: {
519+
...state.sessions[session.taskRunId],
520+
model: modelId,
521+
},
522+
},
523+
}));
524+
log.info("Session model changed", { taskId, modelId });
525+
} catch (error) {
526+
log.error("Failed to change session model", { taskId, modelId, error });
527+
}
528+
},
529+
493530
_subscribeToChannel: (taskRunId, _taskId, channel) => {
494531
if (subscriptions.has(taskRunId)) {
495532
return;

0 commit comments

Comments
 (0)