Skip to content

Commit 821a97f

Browse files
committed
feat(ui/circuits): add remix button to fill NewGraphDialog with initial configs
1 parent 4c750ee commit 821a97f

File tree

2 files changed

+105
-32
lines changed

2 files changed

+105
-32
lines changed

ui-ssr/src/components/circuits/new-graph-dialog.tsx

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,58 @@ import { useDebounce } from '@/hooks/use-debounce'
3535
interface NewGraphDialogProps {
3636
saeSets: string[]
3737
onGraphCreated: (circuitId: string) => void
38+
initialConfig?: {
39+
saeSetName?: string
40+
input?: CircuitInput
41+
desiredLogitProb?: number
42+
maxFeatureNodes?: number
43+
maxNLogits?: number
44+
qkTracingTopk?: number
45+
name?: string
46+
}
47+
trigger?: React.ReactNode
3848
}
3949

4050
export function NewGraphDialog({
4151
saeSets,
4252
onGraphCreated,
53+
initialConfig,
54+
trigger,
4355
}: NewGraphDialogProps) {
4456
const queryClient = useQueryClient()
4557
const [dialogOpen, setDialogOpen] = useState(false)
4658

47-
const [selectedSaeSet, setSelectedSaeSet] = useState<string>('')
48-
const [customGraphId, setCustomGraphId] = useState('')
49-
const [useChatTemplate, setUseChatTemplate] = useState(false)
50-
const [prompt, setPrompt] = useState('')
51-
const [chatMessages, setChatMessages] = useState<ChatMessage[]>([
52-
{ role: 'user', content: '' },
53-
{ role: 'assistant', content: '' },
54-
])
59+
const [selectedSaeSet, setSelectedSaeSet] = useState<string>(
60+
initialConfig?.saeSetName ?? '',
61+
)
62+
const [customGraphId, setCustomGraphId] = useState(initialConfig?.name ?? '')
63+
const [useChatTemplate, setUseChatTemplate] = useState(
64+
initialConfig?.input?.inputType === 'chat_template',
65+
)
66+
const [prompt, setPrompt] = useState(
67+
initialConfig?.input?.inputType === 'plain_text'
68+
? initialConfig.input.text
69+
: '',
70+
)
71+
const [chatMessages, setChatMessages] = useState<ChatMessage[]>(
72+
initialConfig?.input?.inputType === 'chat_template'
73+
? initialConfig.input.messages
74+
: [
75+
{ role: 'user', content: '' },
76+
{ role: 'assistant', content: '' },
77+
],
78+
)
5579

56-
const [desiredLogitProb, setDesiredLogitProb] = useState(0.98)
57-
const [maxNodes, setMaxNodes] = useState(256)
58-
const [maxLogits, setMaxLogits] = useState(1)
59-
const [qkTracingTopk, setQkTracingTopk] = useState(10)
80+
const [desiredLogitProb, setDesiredLogitProb] = useState(
81+
initialConfig?.desiredLogitProb ?? 0.98,
82+
)
83+
const [maxNodes, setMaxNodes] = useState(
84+
initialConfig?.maxFeatureNodes ?? 256,
85+
)
86+
const [maxLogits, setMaxLogits] = useState(initialConfig?.maxNLogits ?? 1)
87+
const [qkTracingTopk, setQkTracingTopk] = useState(
88+
initialConfig?.qkTracingTopk ?? 10,
89+
)
6090

6191
const {
6292
mutate: mutateGenerateCircuit,
@@ -166,17 +196,28 @@ export function NewGraphDialog({
166196
}
167197

168198
const handleReset = () => {
169-
setCustomGraphId('')
170-
setUseChatTemplate(false)
171-
setPrompt('')
172-
setChatMessages([
173-
{ role: 'user', content: '' },
174-
{ role: 'assistant', content: '' },
175-
])
176-
setDesiredLogitProb(0.98)
177-
setMaxNodes(256)
178-
setMaxLogits(1)
179-
setQkTracingTopk(10)
199+
if (initialConfig?.saeSetName !== undefined) {
200+
setSelectedSaeSet(initialConfig.saeSetName)
201+
}
202+
setCustomGraphId(initialConfig?.name ?? '')
203+
setUseChatTemplate(initialConfig?.input?.inputType === 'chat_template')
204+
setPrompt(
205+
initialConfig?.input?.inputType === 'plain_text'
206+
? initialConfig.input.text
207+
: '',
208+
)
209+
setChatMessages(
210+
initialConfig?.input?.inputType === 'chat_template'
211+
? initialConfig.input.messages
212+
: [
213+
{ role: 'user', content: '' },
214+
{ role: 'assistant', content: '' },
215+
],
216+
)
217+
setDesiredLogitProb(initialConfig?.desiredLogitProb ?? 0.98)
218+
setMaxNodes(initialConfig?.maxFeatureNodes ?? 256)
219+
setMaxLogits(initialConfig?.maxNLogits ?? 1)
220+
setQkTracingTopk(initialConfig?.qkTracingTopk ?? 10)
180221
}
181222

182223
const handleDialogClose = () => {
@@ -187,6 +228,7 @@ export function NewGraphDialog({
187228
const handleDialogOpenChange = (open: boolean) => {
188229
if (open) {
189230
setDialogOpen(true)
231+
handleReset()
190232
} else {
191233
handleDialogClose()
192234
}
@@ -197,10 +239,12 @@ export function NewGraphDialog({
197239
return (
198240
<Dialog open={dialogOpen} onOpenChange={handleDialogOpenChange}>
199241
<DialogTrigger asChild>
200-
<Button className="h-12 px-4 gap-2">
201-
<Plus className="h-4 w-4" />
202-
New Graph
203-
</Button>
242+
{trigger || (
243+
<Button className="h-14 px-4 gap-2 font-semibold">
244+
<Plus className="h-4 w-4" />
245+
New Graph
246+
</Button>
247+
)}
204248
</DialogTrigger>
205249
<DialogContent className="max-w-6xl max-h-[90vh]">
206250
<DialogHeader>

ui-ssr/src/routes/circuit.$id.index.tsx

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {
55
useRouter,
66
} from '@tanstack/react-router'
77
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
8-
import { AlertCircle, Loader2 } from 'lucide-react'
8+
import { AlertCircle, GitFork, Loader2 } from 'lucide-react'
99
import { useCallback, useMemo, useState } from 'react'
1010
import { z } from 'zod'
1111
import type { CircuitData, FeatureNode, VisState } from '@/types/circuit'
@@ -25,6 +25,7 @@ import { NodeConnections } from '@/components/circuits/node-connections'
2525
import { ThresholdControls } from '@/components/circuits/threshold-controls'
2626
import { FeatureCardHorizontal } from '@/components/feature/feature-card-horizontal'
2727
import { Card } from '@/components/ui/card'
28+
import { Button } from '@/components/ui/button'
2829
import { createRawEdgeIndex, createRawNodeIndex } from '@/utils/circuit-index'
2930

3031
const searchParamsSchema = z.object({
@@ -313,10 +314,38 @@ function CircuitPage() {
313314
onSelect={handleCircuitSelect}
314315
/>
315316
</div>
316-
<NewGraphDialog
317-
saeSets={saeSets}
318-
onGraphCreated={handleGraphCreated}
319-
/>
317+
<div className="flex gap-2">
318+
<NewGraphDialog
319+
saeSets={saeSets}
320+
onGraphCreated={handleGraphCreated}
321+
/>
322+
{circuitData && (
323+
<NewGraphDialog
324+
saeSets={saeSets}
325+
onGraphCreated={handleGraphCreated}
326+
initialConfig={{
327+
saeSetName: circuitData.saeSetName,
328+
input: circuitData.input,
329+
desiredLogitProb: circuitData.config.desiredLogitProb,
330+
maxFeatureNodes: circuitData.config.maxFeatureNodes,
331+
maxNLogits: circuitData.config.maxNLogits,
332+
qkTracingTopk: circuitData.config.qkTracingTopk,
333+
name: circuitData.name
334+
? `${circuitData.name}-remix`
335+
: undefined,
336+
}}
337+
trigger={
338+
<Button
339+
variant="outline"
340+
className="h-14 px-4 gap-2 text-blue-700 border-blue-200 bg-blue-50 hover:bg-blue-100 hover:border-blue-300 transition-colors font-semibold"
341+
>
342+
<GitFork className="h-4 w-4" />
343+
Remix
344+
</Button>
345+
}
346+
/>
347+
)}
348+
</div>
320349
</div>
321350
<div className="flex-1" />
322351
</div>

0 commit comments

Comments
 (0)