@@ -44,11 +44,19 @@ export interface ModelConfig {
44
44
api_version ?: string ;
45
45
}
46
46
47
+ // Define model slot types
48
+ export type ModelSlotType = 'generation' | 'hint' ;
49
+
50
+ export interface ModelSlots {
51
+ generation ?: string ; // model id assigned to generation tasks
52
+ hint ?: string ; // model id assigned to hint tasks
53
+ }
54
+
47
55
// Define a type for the slice state
48
56
export interface DataFormulatorState {
49
57
sessionId : string | undefined ;
50
58
models : ModelConfig [ ] ;
51
- selectedModelId : string | undefined ;
59
+ modelSlots : ModelSlots ;
52
60
testedModels : { id : string , status : 'ok' | 'error' | 'testing' | 'unknown' , message : string } [ ] ;
53
61
54
62
tables : DictTable [ ] ;
@@ -89,7 +97,7 @@ export interface DataFormulatorState {
89
97
const initialState : DataFormulatorState = {
90
98
sessionId : undefined ,
91
99
models : [ ] ,
92
- selectedModelId : undefined ,
100
+ modelSlots : { } ,
93
101
testedModels : [ ] ,
94
102
95
103
tables : [ ] ,
@@ -263,7 +271,7 @@ export const dataFormulatorSlice = createSlice({
263
271
// avoid resetting inputted models
264
272
// state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default');
265
273
266
- state . selectedModelId = state . models . length > 0 ? state . models [ 0 ] . id : undefined ;
274
+ state . modelSlots = { } ;
267
275
state . testedModels = [ ] ;
268
276
269
277
state . tables = [ ] ;
@@ -289,7 +297,7 @@ export const dataFormulatorSlice = createSlice({
289
297
let savedState = action . payload ;
290
298
291
299
state . models = savedState . models ;
292
- state . selectedModelId = savedState . selectedModelId ;
300
+ state . modelSlots = savedState . modelSlots || { } ;
293
301
state . testedModels = [ ] ; // models should be tested again
294
302
295
303
//state.table = undefined;
@@ -318,16 +326,25 @@ export const dataFormulatorSlice = createSlice({
318
326
state . config = action . payload ;
319
327
} ,
320
328
selectModel : ( state , action : PayloadAction < string | undefined > ) => {
321
- state . selectedModelId = action . payload ;
329
+ state . modelSlots = { ...state . modelSlots , generation : action . payload } ;
330
+ } ,
331
+ setModelSlot : ( state , action : PayloadAction < { slotType : ModelSlotType , modelId : string | undefined } > ) => {
332
+ state . modelSlots = { ...state . modelSlots , [ action . payload . slotType ] : action . payload . modelId } ;
333
+ } ,
334
+ setModelSlots : ( state , action : PayloadAction < ModelSlots > ) => {
335
+ state . modelSlots = action . payload ;
322
336
} ,
323
337
addModel : ( state , action : PayloadAction < ModelConfig > ) => {
324
338
state . models = [ ...state . models , action . payload ] ;
325
339
} ,
326
340
removeModel : ( state , action : PayloadAction < string > ) => {
327
341
state . models = state . models . filter ( model => model . id != action . payload ) ;
328
- if ( state . selectedModelId == action . payload ) {
329
- state . selectedModelId = undefined ;
330
- }
342
+ // Remove the model from all slots if it's assigned
343
+ Object . keys ( state . modelSlots ) . forEach ( slotType => {
344
+ if ( state . modelSlots [ slotType as ModelSlotType ] === action . payload ) {
345
+ state . modelSlots [ slotType as ModelSlotType ] = undefined ;
346
+ }
347
+ } ) ;
331
348
} ,
332
349
updateModelStatus : ( state , action : PayloadAction < { id : string , status : 'ok' | 'error' | 'testing' | 'unknown' , message : string } > ) => {
333
350
let id = action . payload . id ;
@@ -735,16 +752,19 @@ export const dataFormulatorSlice = createSlice({
735
752
736
753
state . models = [
737
754
...defaultModels ,
738
- ...state . models . filter ( e => ! defaultModels . map ( ( m : ModelConfig ) => m . endpoint ) . includes ( e . endpoint ) )
755
+ ...state . models . filter ( e => ! defaultModels . some ( ( m : ModelConfig ) =>
756
+ m . endpoint === e . endpoint && m . model === e . model &&
757
+ m . api_base === e . api_base && m . api_version === e . api_version
758
+ ) )
739
759
] ;
740
760
741
761
state . testedModels = [
742
762
...defaultModels . map ( ( m : ModelConfig ) => { return { id : m . id , status : 'ok' } } ) ,
743
763
...state . testedModels . filter ( t => ! defaultModels . map ( ( m : ModelConfig ) => m . id ) . includes ( t . id ) )
744
764
]
745
765
746
- if ( state . selectedModelId == undefined && defaultModels . length > 0 ) {
747
- state . selectedModelId = defaultModels [ 0 ] . id ;
766
+ if ( state . modelSlots . generation == undefined && defaultModels . length > 0 ) {
767
+ state . modelSlots . generation = defaultModels [ 0 ] . id ;
748
768
}
749
769
750
770
// console.log("load model complete");
@@ -769,7 +789,14 @@ export const dataFormulatorSlice = createSlice({
769
789
770
790
export const dfSelectors = {
771
791
getActiveModel : ( state : DataFormulatorState ) : ModelConfig => {
772
- return state . models . find ( m => m . id == state . selectedModelId ) || state . models [ 0 ] ;
792
+ return state . models . find ( m => m . id == state . modelSlots . generation ) || state . models [ 0 ] ;
793
+ } ,
794
+ getModelBySlot : ( state : DataFormulatorState , slotType : ModelSlotType ) : ModelConfig | undefined => {
795
+ const modelId = state . modelSlots [ slotType ] ;
796
+ return modelId ? state . models . find ( m => m . id === modelId ) : undefined ;
797
+ } ,
798
+ getAllSlotTypes : ( ) : ModelSlotType [ ] => {
799
+ return [ 'generation' , 'hint' ] ;
773
800
} ,
774
801
getActiveBaseTableIds : ( state : DataFormulatorState ) => {
775
802
let focusedTableId = state . focusedTableId ;
0 commit comments