Skip to content

Commit 3699b24

Browse files
committed
model selection
1 parent 0c1ee6b commit 3699b24

File tree

11 files changed

+265
-23
lines changed

11 files changed

+265
-23
lines changed

apps/array/src/main/preload.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ contextBridge.exposeInMainWorld("electronAPI", {
193193
newToken: string,
194194
): Promise<void> =>
195195
ipcRenderer.invoke("agent-token-refresh", taskRunId, newToken),
196+
agentSetModel: async (sessionId: string, modelId: string): Promise<void> =>
197+
ipcRenderer.invoke("agent-set-model", sessionId, modelId),
196198
onAgentEvent: (
197199
channel: string,
198200
listener: (payload: unknown) => void,

apps/array/src/main/services/session-manager.ts

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ export interface SessionConfig {
130130
credentials: PostHogCredentials;
131131
logUrl?: string; // For reconnection from S3
132132
sdkSessionId?: string; // SDK session ID for resuming Claude Code context
133+
model?: string;
133134
}
134135

135136
export interface ManagedSession {
@@ -219,8 +220,15 @@ export class SessionManager {
219220
config: SessionConfig,
220221
isReconnect: boolean,
221222
): Promise<ManagedSession | null> {
222-
const { taskId, taskRunId, repoPath, credentials, logUrl, sdkSessionId } =
223-
config;
223+
const {
224+
taskId,
225+
taskRunId,
226+
repoPath,
227+
credentials,
228+
logUrl,
229+
sdkSessionId,
230+
model,
231+
} = config;
224232

225233
const existing = this.sessions.get(taskRunId);
226234
if (existing) {
@@ -276,7 +284,7 @@ export class SessionManager {
276284
await connection.newSession({
277285
cwd: repoPath,
278286
mcpServers,
279-
_meta: { sessionId: taskRunId },
287+
_meta: { sessionId: taskRunId, model },
280288
});
281289
}
282290

@@ -355,6 +363,24 @@ export class SessionManager {
355363
return this.sessions.get(taskRunId);
356364
}
357365

366+
async setSessionModel(taskRunId: string, modelId: string): Promise<void> {
367+
const session = this.sessions.get(taskRunId);
368+
if (!session) {
369+
throw new Error(`Session not found: ${taskRunId}`);
370+
}
371+
372+
try {
373+
await session.connection.extMethod("session/setModel", {
374+
sessionId: taskRunId,
375+
modelId,
376+
});
377+
log.info("Session model updated", { taskRunId, modelId });
378+
} catch (err) {
379+
log.error("Failed to set session model", { taskRunId, modelId, err });
380+
throw err;
381+
}
382+
}
383+
358384
listSessions(taskId?: string): ManagedSession[] {
359385
const all = Array.from(this.sessions.values());
360386
return taskId ? all.filter((s) => s.taskId === taskId) : all;
@@ -507,6 +533,7 @@ interface AgentSessionParams {
507533
projectId: number;
508534
logUrl?: string;
509535
sdkSessionId?: string;
536+
model?: string;
510537
}
511538

512539
type SessionResponse = { sessionId: string; channel: string };
@@ -536,6 +563,7 @@ function toSessionConfig(params: AgentSessionParams): SessionConfig {
536563
},
537564
logUrl: params.logUrl,
538565
sdkSessionId: params.sdkSessionId,
566+
model: params.model,
539567
};
540568
}
541569

@@ -648,4 +676,15 @@ export function registerAgentIpc(
648676
sessionManager.updateSessionToken(taskRunId, newToken);
649677
},
650678
);
679+
680+
ipcMain.handle(
681+
"agent-set-model",
682+
async (
683+
_event: IpcMainInvokeEvent,
684+
sessionId: string,
685+
modelId: string,
686+
): Promise<void> => {
687+
await sessionManager.setSessionModel(sessionId, modelId);
688+
},
689+
);
651690
}

apps/array/src/renderer/features/sessions/components/MessageEditor.tsx

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ArrowUp, Paperclip, Stop } from "@phosphor-icons/react";
33
import { Box, Flex, IconButton, Tooltip } from "@radix-ui/themes";
44
import { logger } from "@renderer/lib/logger";
55
import type { MentionItem } from "@shared/types";
6+
import { ModelSelector } from "./ModelSelector";
67
import { Extension, type JSONContent } from "@tiptap/core";
78
import { Mention } from "@tiptap/extension-mention";
89
import { Placeholder } from "@tiptap/extension-placeholder";
@@ -174,6 +175,7 @@ export interface MessageEditorHandle {
174175

175176
interface MessageEditorProps {
176177
sessionId: string;
178+
taskId?: string;
177179
placeholder?: string;
178180
repoPath?: string | null;
179181
disabled?: boolean;
@@ -191,6 +193,7 @@ export const MessageEditor = forwardRef<
191193
(
192194
{
193195
sessionId,
196+
taskId,
194197
placeholder = "Type a message... @ to mention files",
195198
repoPath,
196199
disabled = false,
@@ -449,26 +452,29 @@ export const MessageEditor = forwardRef<
449452
<EditorContent editor={editor} />
450453
</Box>
451454
<Flex justify="between" align="center">
452-
<input
453-
ref={fileInputRef}
454-
type="file"
455-
multiple
456-
onChange={handleFileSelect}
457-
style={{ display: "none" }}
458-
/>
459-
<Tooltip content="Attach file">
460-
<IconButton
461-
size="1"
462-
variant="ghost"
463-
color="gray"
464-
onClick={() => fileInputRef.current?.click()}
465-
disabled={disabled}
466-
title="Attach file"
467-
style={{ marginLeft: "0px" }}
468-
>
469-
<Paperclip size={14} weight="bold" />
470-
</IconButton>
471-
</Tooltip>
455+
<Flex gap="2" align="center">
456+
<input
457+
ref={fileInputRef}
458+
type="file"
459+
multiple
460+
onChange={handleFileSelect}
461+
style={{ display: "none" }}
462+
/>
463+
<Tooltip content="Attach file">
464+
<IconButton
465+
size="1"
466+
variant="ghost"
467+
color="gray"
468+
onClick={() => fileInputRef.current?.click()}
469+
disabled={disabled}
470+
title="Attach file"
471+
style={{ marginLeft: "0px" }}
472+
>
473+
<Paperclip size={14} weight="bold" />
474+
</IconButton>
475+
</Tooltip>
476+
<ModelSelector taskId={taskId} disabled={disabled} />
477+
</Flex>
472478
<Flex gap="4" align="center">
473479
{isLoading && onCancel ? (
474480
<Tooltip content="Stop">
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import { useSettingsStore } from "@features/settings/stores/settingsStore";
2+
import { Select, Text } from "@radix-ui/themes";
3+
import {
4+
AVAILABLE_MODELS,
5+
getModelsByProvider,
6+
type ModelProvider,
7+
} from "@shared/types/models";
8+
import { Fragment } from "react";
9+
import { useSessionStore } from "../stores/sessionStore";
10+
11+
interface ModelSelectorProps {
12+
taskId?: string;
13+
disabled?: boolean;
14+
onModelChange?: (modelId: string) => void;
15+
}
16+
17+
export function ModelSelector({
18+
taskId,
19+
disabled,
20+
onModelChange,
21+
}: ModelSelectorProps) {
22+
const selectedModel = useSettingsStore((state) => state.selectedModel);
23+
const setSelectedModel = useSettingsStore((state) => state.setSelectedModel);
24+
const setSessionModel = useSessionStore((state) => state.setSessionModel);
25+
const session = useSessionStore((state) =>
26+
taskId ? state.getSessionForTask(taskId) : undefined,
27+
);
28+
29+
const handleChange = (value: string) => {
30+
setSelectedModel(value);
31+
onModelChange?.(value);
32+
33+
// If there's an active session, update the model mid-session
34+
if (taskId && session?.status === "connected" && !session.isCloud) {
35+
setSessionModel(taskId, value);
36+
}
37+
};
38+
39+
const modelsByProvider = getModelsByProvider();
40+
const providers = (Object.keys(modelsByProvider) as ModelProvider[]).filter(
41+
(provider) => modelsByProvider[provider].models.length > 0,
42+
);
43+
44+
const currentModel = AVAILABLE_MODELS.find((m) => m.id === selectedModel);
45+
const displayName = currentModel?.name ?? selectedModel;
46+
47+
return (
48+
<Select.Root
49+
value={selectedModel}
50+
onValueChange={handleChange}
51+
disabled={disabled}
52+
size="1"
53+
>
54+
<Select.Trigger
55+
variant="ghost"
56+
style={{
57+
fontSize: "var(--font-size-1)",
58+
color: "var(--gray-11)",
59+
padding: "4px 8px",
60+
marginLeft: "4px",
61+
height: "auto",
62+
minHeight: "unset",
63+
}}
64+
>
65+
<Text size="1" style={{ fontFamily: "var(--font-mono)" }}>
66+
{displayName}
67+
</Text>
68+
</Select.Trigger>
69+
<Select.Content position="popper" sideOffset={4}>
70+
{providers.map((provider, index) => (
71+
<Fragment key={provider}>
72+
{index > 0 && <Select.Separator />}
73+
<Select.Group>
74+
<Select.Label>{modelsByProvider[provider].name}</Select.Label>
75+
{modelsByProvider[provider].models.map((model) => (
76+
<Select.Item key={model.id} value={model.id}>
77+
{model.name}
78+
</Select.Item>
79+
))}
80+
</Select.Group>
81+
</Fragment>
82+
))}
83+
</Select.Content>
84+
</Select.Root>
85+
);
86+
}

apps/array/src/renderer/features/sessions/components/SessionView.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function RawLogEntry({
6464
interface SessionViewProps {
6565
events: SessionEvent[];
6666
sessionId: string | null;
67+
taskId?: string;
6768
isRunning: boolean;
6869
isPromptPending?: boolean;
6970
onSendPrompt: (text: string) => void;
@@ -342,6 +343,7 @@ function groupMessagesIntoTurns(
342343
export function SessionView({
343344
events,
344345
sessionId,
346+
taskId,
345347
isRunning,
346348
isPromptPending,
347349
onSendPrompt,
@@ -615,6 +617,7 @@ export function SessionView({
615617
<Box className="border-gray-6 border-t p-3">
616618
<MessageEditor
617619
sessionId={sessionId ?? "default"}
620+
taskId={taskId}
618621
placeholder="Type a message... @ to mention files"
619622
repoPath={repoPath}
620623
disabled={!isRunning}

apps/array/src/renderer/features/sessions/stores/sessionStore.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type {
33
SessionNotification,
44
} from "@agentclientprotocol/sdk";
55
import { useAuthStore } from "@features/auth/stores/authStore";
6+
import { useSettingsStore } from "@features/settings/stores/settingsStore";
67
import { logger } from "@renderer/lib/logger";
78
import type { Task } from "@shared/types";
89
import { create } from "zustand";
@@ -126,6 +127,9 @@ interface SessionStore {
126127
// Cancel ongoing prompt without terminating session
127128
cancelPrompt: (taskId: string) => Promise<boolean>;
128129

130+
// Change model for active session
131+
setSessionModel: (taskId: string, modelId: string) => Promise<void>;
132+
129133
// Internal: subscribe to IPC events
130134
_subscribeToChannel: (
131135
taskRunId: string,
@@ -317,13 +321,15 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
317321
return;
318322
}
319323

324+
const selectedModel = useSettingsStore.getState().selectedModel;
320325
const result = await window.electronAPI.agentStart({
321326
taskId,
322327
taskRunId: taskRun.id,
323328
repoPath,
324329
apiKey,
325330
apiHost,
326331
projectId,
332+
model: selectedModel,
327333
});
328334

329335
set((state) => ({
@@ -490,6 +496,26 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
490496
}
491497
},
492498

499+
setSessionModel: async (taskId, modelId) => {
500+
const session = get().getSessionForTask(taskId);
501+
if (!session) {
502+
log.warn("No session found for model change", { taskId });
503+
return;
504+
}
505+
506+
if (session.isCloud) {
507+
log.warn("Model change not supported for cloud sessions");
508+
return;
509+
}
510+
511+
try {
512+
await window.electronAPI.agentSetModel(session.taskRunId, modelId);
513+
log.info("Session model changed", { taskId, modelId });
514+
} catch (error) {
515+
log.error("Failed to change session model", { taskId, modelId, error });
516+
}
517+
},
518+
493519
_subscribeToChannel: (taskRunId, _taskId, channel) => {
494520
if (subscriptions.has(taskRunId)) {
495521
return;

apps/array/src/renderer/features/settings/stores/settingsStore.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { WorkspaceMode } from "@shared/types";
2+
import { DEFAULT_MODEL } from "@shared/types/models";
23
import { create } from "zustand";
34
import { persist } from "zustand/middleware";
45

@@ -12,13 +13,15 @@ interface SettingsStore {
1213
lastUsedLocalWorkspaceMode: LocalWorkspaceMode;
1314
lastUsedWorkspaceMode: WorkspaceMode;
1415
createPR: boolean;
16+
selectedModel: string;
1517

1618
setAutoRunTasks: (autoRun: boolean) => void;
1719
setDefaultRunMode: (mode: DefaultRunMode) => void;
1820
setLastUsedRunMode: (mode: "local" | "cloud") => void;
1921
setLastUsedLocalWorkspaceMode: (mode: LocalWorkspaceMode) => void;
2022
setLastUsedWorkspaceMode: (mode: WorkspaceMode) => void;
2123
setCreatePR: (createPR: boolean) => void;
24+
setSelectedModel: (model: string) => void;
2225
}
2326

2427
export const useSettingsStore = create<SettingsStore>()(
@@ -30,6 +33,7 @@ export const useSettingsStore = create<SettingsStore>()(
3033
lastUsedLocalWorkspaceMode: "worktree",
3134
lastUsedWorkspaceMode: "worktree",
3235
createPR: true,
36+
selectedModel: DEFAULT_MODEL,
3337

3438
setAutoRunTasks: (autoRun) => set({ autoRunTasks: autoRun }),
3539
setDefaultRunMode: (mode) => set({ defaultRunMode: mode }),
@@ -38,6 +42,7 @@ export const useSettingsStore = create<SettingsStore>()(
3842
set({ lastUsedLocalWorkspaceMode: mode }),
3943
setLastUsedWorkspaceMode: (mode) => set({ lastUsedWorkspaceMode: mode }),
4044
setCreatePR: (createPR) => set({ createPR }),
45+
setSelectedModel: (model) => set({ selectedModel: model }),
4146
}),
4247
{
4348
name: "settings-storage",

apps/array/src/renderer/features/task-detail/components/TaskLogsPanel.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ export function TaskLogsPanel({ taskId, task }: TaskLogsPanelProps) {
9999
<SessionView
100100
events={session?.events ?? []}
101101
sessionId={session?.taskRunId ?? null}
102+
taskId={taskId}
102103
isRunning={isRunning}
103104
isPromptPending={session?.isPromptPending}
104105
onSendPrompt={handleSendPrompt}

0 commit comments

Comments
 (0)