@@ -35,28 +35,58 @@ import { useDebounce } from '@/hooks/use-debounce'
3535interface NewGraphDialogProps {
3636 saeSets : string [ ]
3737 onGraphCreated : ( circuitId : string ) => void
38+ initialConfig ?: {
39+ saeSetName ?: string
40+ input ?: CircuitInput
41+ desiredLogitProb ?: number
42+ maxFeatureNodes ?: number
43+ maxNLogits ?: number
44+ qkTracingTopk ?: number
45+ name ?: string
46+ }
47+ trigger ?: React . ReactNode
3848}
3949
4050export function NewGraphDialog ( {
4151 saeSets,
4252 onGraphCreated,
53+ initialConfig,
54+ trigger,
4355} : NewGraphDialogProps ) {
4456 const queryClient = useQueryClient ( )
4557 const [ dialogOpen , setDialogOpen ] = useState ( false )
4658
47- const [ selectedSaeSet , setSelectedSaeSet ] = useState < string > ( '' )
48- const [ customGraphId , setCustomGraphId ] = useState ( '' )
49- const [ useChatTemplate , setUseChatTemplate ] = useState ( false )
50- const [ prompt , setPrompt ] = useState ( '' )
51- const [ chatMessages , setChatMessages ] = useState < ChatMessage [ ] > ( [
52- { role : 'user' , content : '' } ,
53- { role : 'assistant' , content : '' } ,
54- ] )
59+ const [ selectedSaeSet , setSelectedSaeSet ] = useState < string > (
60+ initialConfig ?. saeSetName ?? '' ,
61+ )
62+ const [ customGraphId , setCustomGraphId ] = useState ( initialConfig ?. name ?? '' )
63+ const [ useChatTemplate , setUseChatTemplate ] = useState (
64+ initialConfig ?. input ?. inputType === 'chat_template' ,
65+ )
66+ const [ prompt , setPrompt ] = useState (
67+ initialConfig ?. input ?. inputType === 'plain_text'
68+ ? initialConfig . input . text
69+ : '' ,
70+ )
71+ const [ chatMessages , setChatMessages ] = useState < ChatMessage [ ] > (
72+ initialConfig ?. input ?. inputType === 'chat_template'
73+ ? initialConfig . input . messages
74+ : [
75+ { role : 'user' , content : '' } ,
76+ { role : 'assistant' , content : '' } ,
77+ ] ,
78+ )
5579
56- const [ desiredLogitProb , setDesiredLogitProb ] = useState ( 0.98 )
57- const [ maxNodes , setMaxNodes ] = useState ( 256 )
58- const [ maxLogits , setMaxLogits ] = useState ( 1 )
59- const [ qkTracingTopk , setQkTracingTopk ] = useState ( 10 )
80+ const [ desiredLogitProb , setDesiredLogitProb ] = useState (
81+ initialConfig ?. desiredLogitProb ?? 0.98 ,
82+ )
83+ const [ maxNodes , setMaxNodes ] = useState (
84+ initialConfig ?. maxFeatureNodes ?? 256 ,
85+ )
86+ const [ maxLogits , setMaxLogits ] = useState ( initialConfig ?. maxNLogits ?? 1 )
87+ const [ qkTracingTopk , setQkTracingTopk ] = useState (
88+ initialConfig ?. qkTracingTopk ?? 10 ,
89+ )
6090
6191 const {
6292 mutate : mutateGenerateCircuit ,
@@ -166,17 +196,28 @@ export function NewGraphDialog({
166196 }
167197
168198 const handleReset = ( ) => {
169- setCustomGraphId ( '' )
170- setUseChatTemplate ( false )
171- setPrompt ( '' )
172- setChatMessages ( [
173- { role : 'user' , content : '' } ,
174- { role : 'assistant' , content : '' } ,
175- ] )
176- setDesiredLogitProb ( 0.98 )
177- setMaxNodes ( 256 )
178- setMaxLogits ( 1 )
179- setQkTracingTopk ( 10 )
199+ if ( initialConfig ?. saeSetName !== undefined ) {
200+ setSelectedSaeSet ( initialConfig . saeSetName )
201+ }
202+ setCustomGraphId ( initialConfig ?. name ?? '' )
203+ setUseChatTemplate ( initialConfig ?. input ?. inputType === 'chat_template' )
204+ setPrompt (
205+ initialConfig ?. input ?. inputType === 'plain_text'
206+ ? initialConfig . input . text
207+ : '' ,
208+ )
209+ setChatMessages (
210+ initialConfig ?. input ?. inputType === 'chat_template'
211+ ? initialConfig . input . messages
212+ : [
213+ { role : 'user' , content : '' } ,
214+ { role : 'assistant' , content : '' } ,
215+ ] ,
216+ )
217+ setDesiredLogitProb ( initialConfig ?. desiredLogitProb ?? 0.98 )
218+ setMaxNodes ( initialConfig ?. maxFeatureNodes ?? 256 )
219+ setMaxLogits ( initialConfig ?. maxNLogits ?? 1 )
220+ setQkTracingTopk ( initialConfig ?. qkTracingTopk ?? 10 )
180221 }
181222
182223 const handleDialogClose = ( ) => {
@@ -187,6 +228,7 @@ export function NewGraphDialog({
187228 const handleDialogOpenChange = ( open : boolean ) => {
188229 if ( open ) {
189230 setDialogOpen ( true )
231+ handleReset ( )
190232 } else {
191233 handleDialogClose ( )
192234 }
@@ -197,10 +239,12 @@ export function NewGraphDialog({
197239 return (
198240 < Dialog open = { dialogOpen } onOpenChange = { handleDialogOpenChange } >
199241 < DialogTrigger asChild >
200- < Button className = "h-12 px-4 gap-2" >
201- < Plus className = "h-4 w-4" />
202- New Graph
203- </ Button >
242+ { trigger || (
243+ < Button className = "h-14 px-4 gap-2 font-semibold" >
244+ < Plus className = "h-4 w-4" />
245+ New Graph
246+ </ Button >
247+ ) }
204248 </ DialogTrigger >
205249 < DialogContent className = "max-w-6xl max-h-[90vh]" >
206250 < DialogHeader >
0 commit comments