Skip to content

Commit e867ec4

Browse files
committed
model dialog prep
1 parent b895f39 commit e867ec4

File tree

2 files changed

+297
-95
lines changed

2 files changed

+297
-95
lines changed

src/app/dfSlice.tsx

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,19 @@ export interface ModelConfig {
4444
api_version?: string;
4545
}
4646

47+
// Define model slot types
48+
export type ModelSlotType = 'generation' | 'hint';
49+
50+
export interface ModelSlots {
51+
generation?: string; // model id assigned to generation tasks
52+
hint?: string; // model id assigned to hint tasks
53+
}
54+
4755
// Define a type for the slice state
4856
export interface DataFormulatorState {
4957
sessionId: string | undefined;
5058
models: ModelConfig[];
51-
selectedModelId: string | undefined;
59+
modelSlots: ModelSlots;
5260
testedModels: {id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[];
5361

5462
tables : DictTable[];
@@ -89,7 +97,7 @@ export interface DataFormulatorState {
8997
const initialState: DataFormulatorState = {
9098
sessionId: undefined,
9199
models: [],
92-
selectedModelId: undefined,
100+
modelSlots: {},
93101
testedModels: [],
94102

95103
tables: [],
@@ -263,7 +271,7 @@ export const dataFormulatorSlice = createSlice({
263271
// avoid resetting inputted models
264272
// state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default');
265273

266-
state.selectedModelId = state.models.length > 0 ? state.models[0].id : undefined;
274+
state.modelSlots = {};
267275
state.testedModels = [];
268276

269277
state.tables = [];
@@ -289,7 +297,7 @@ export const dataFormulatorSlice = createSlice({
289297
let savedState = action.payload;
290298

291299
state.models = savedState.models;
292-
state.selectedModelId = savedState.selectedModelId;
300+
state.modelSlots = savedState.modelSlots || {};
293301
state.testedModels = []; // models should be tested again
294302

295303
//state.table = undefined;
@@ -318,16 +326,25 @@ export const dataFormulatorSlice = createSlice({
318326
state.config = action.payload;
319327
},
320328
selectModel: (state, action: PayloadAction<string | undefined>) => {
321-
state.selectedModelId = action.payload;
329+
state.modelSlots = { ...state.modelSlots, generation: action.payload };
330+
},
331+
setModelSlot: (state, action: PayloadAction<{slotType: ModelSlotType, modelId: string | undefined}>) => {
332+
state.modelSlots = { ...state.modelSlots, [action.payload.slotType]: action.payload.modelId };
333+
},
334+
setModelSlots: (state, action: PayloadAction<ModelSlots>) => {
335+
state.modelSlots = action.payload;
322336
},
323337
addModel: (state, action: PayloadAction<ModelConfig>) => {
324338
state.models = [...state.models, action.payload];
325339
},
326340
removeModel: (state, action: PayloadAction<string>) => {
327341
state.models = state.models.filter(model => model.id != action.payload);
328-
if (state.selectedModelId == action.payload) {
329-
state.selectedModelId = undefined;
330-
}
342+
// Remove the model from all slots if it's assigned
343+
Object.keys(state.modelSlots).forEach(slotType => {
344+
if (state.modelSlots[slotType as ModelSlotType] === action.payload) {
345+
state.modelSlots[slotType as ModelSlotType] = undefined;
346+
}
347+
});
331348
},
332349
updateModelStatus: (state, action: PayloadAction<{id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}>) => {
333350
let id = action.payload.id;
@@ -743,8 +760,8 @@ export const dataFormulatorSlice = createSlice({
743760
...state.testedModels.filter(t => !defaultModels.map((m: ModelConfig) => m.id).includes(t.id))
744761
]
745762

746-
if (state.selectedModelId == undefined && defaultModels.length > 0) {
747-
state.selectedModelId = defaultModels[0].id;
763+
if (state.modelSlots.generation == undefined && defaultModels.length > 0) {
764+
state.modelSlots.generation = defaultModels[0].id;
748765
}
749766

750767
// console.log("load model complete");
@@ -769,7 +786,14 @@ export const dataFormulatorSlice = createSlice({
769786

770787
export const dfSelectors = {
771788
getActiveModel: (state: DataFormulatorState) : ModelConfig => {
772-
return state.models.find(m => m.id == state.selectedModelId) || state.models[0];
789+
return state.models.find(m => m.id == state.modelSlots.generation) || state.models[0];
790+
},
791+
getModelBySlot: (state: DataFormulatorState, slotType: ModelSlotType) : ModelConfig | undefined => {
792+
const modelId = state.modelSlots[slotType];
793+
return modelId ? state.models.find(m => m.id === modelId) : undefined;
794+
},
795+
getAllSlotTypes: () : ModelSlotType[] => {
796+
return ['generation', 'hint'];
773797
},
774798
getActiveBaseTableIds: (state: DataFormulatorState) => {
775799
let focusedTableId = state.focusedTableId;

0 commit comments

Comments
 (0)