Skip to content

Commit 400badb

Browse files
author
Attila Cseh
committed
canvasStagingAreaSlice refactored
1 parent 01caa0b commit 400badb

File tree

3 files changed

+30
-24
lines changed

3 files changed

+30
-24
lines changed

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/store';
33
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
44
import {
55
buildSelectIsStagingBySessionId,
6-
selectSelectedCanvasSessionId,
6+
selectActiveCanvasSessionId,
77
} from 'features/controlLayers/store/canvasStagingAreaSlice';
88
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
99
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
@@ -162,7 +162,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
162162
if (modelBase !== state.params.model?.base) {
163163
// Sync generate tab settings whenever the model base changes
164164
dispatch(syncedToOptimalDimension());
165-
const sessionId = selectSelectedCanvasSessionId(state);
165+
const sessionId = selectActiveCanvasSessionId(state);
166166
const selectIsStaging = buildSelectIsStagingBySessionId(sessionId);
167167
const isStaging = selectIsStaging(state);
168168
if (!isStaging) {

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { isNil } from 'es-toolkit';
33
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
44
import {
55
buildSelectIsStagingBySessionId,
6-
selectSelectedCanvasSessionId,
6+
selectActiveCanvasSessionId,
77
} from 'features/controlLayers/store/canvasStagingAreaSlice';
88
import {
99
heightChanged,
@@ -118,7 +118,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
118118
}
119119
const setSizeOptions = { updateAspectRatio: true, clamp: true };
120120

121-
const sessionId = selectSelectedCanvasSessionId(state);
121+
const sessionId = selectActiveCanvasSessionId(state);
122122
const selectIsStaging = buildSelectIsStagingBySessionId(sessionId);
123123
const isStaging = selectIsStaging(state);
124124

invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ import {
1717
import { selectActiveCanvasId } from './selectors';
1818

1919
const zCanvasSessionState = z.object({
20-
canvasId: z.string(),
20+
canvasId: z.string().min(1),
2121
canvasSessionId: z.string(),
2222
canvasDiscardedQueueItems: z.array(z.number().int()),
2323
});
2424
type CanvasSessionState = z.infer<typeof zCanvasSessionState>;
2525
const zCanvasStagingAreaState = z.object({
2626
_version: z.literal(2),
27-
sessions: z.array(zCanvasSessionState),
27+
sessions: z.record(z.string(), zCanvasSessionState),
2828
});
2929
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
3030

@@ -39,17 +39,17 @@ const getInitialCanvasSessionState = (canvasId: string): CanvasSessionState => (
3939

4040
const getInitialState = (): CanvasStagingAreaState => ({
4141
_version: 2,
42-
sessions: [],
42+
sessions: {},
4343
});
4444

45-
const slice = createSlice({
45+
const canvasStagingAreaSlice = createSlice({
4646
name: 'canvasSession',
4747
initialState: getInitialState(),
4848
reducers: {
4949
canvasQueueItemDiscarded: (state, action: CanvasPayloadAction<{ itemId: number }>) => {
5050
const { canvasId, itemId } = action.payload;
5151

52-
const session = state.sessions.find((session) => session.canvasId === canvasId);
52+
const session = state.sessions[canvasId];
5353
if (!session) {
5454
return;
5555
}
@@ -62,7 +62,7 @@ const slice = createSlice({
6262
reducer: (state, action: CanvasPayloadAction<{ canvasSessionId: string }>) => {
6363
const { canvasId, canvasSessionId } = action.payload;
6464

65-
const session = state.sessions.find((session) => session.canvasId === canvasId);
65+
const session = state.sessions[canvasId];
6666
if (!session) {
6767
return;
6868
}
@@ -83,25 +83,27 @@ const slice = createSlice({
8383
extraReducers(builder) {
8484
builder.addCase(canvasCreated, (state, action) => {
8585
const session = getInitialCanvasSessionState(action.payload.canvasId);
86-
state.sessions.push(session);
86+
state.sessions[session.canvasId] = session;
8787
});
8888
builder.addCase(canvasRemoved, (state, action) => {
89-
state.sessions = state.sessions.filter((session) => session.canvasId !== action.payload.canvasId);
89+
delete state.sessions[action.payload.canvasId];
9090
});
9191
builder.addCase(canvasMultiCanvasMigrated, (state, action) => {
92-
const session = state.sessions.find((session) => session.canvasId === MIGRATION_MULTI_CANVAS_ID_PLACEHOLDER);
92+
const session = state.sessions[MIGRATION_MULTI_CANVAS_ID_PLACEHOLDER];
9393
if (!session) {
9494
return;
9595
}
9696
session.canvasId = action.payload.canvasId;
97+
state.sessions[session.canvasId] = session;
98+
delete state.sessions[MIGRATION_MULTI_CANVAS_ID_PLACEHOLDER];
9799
});
98100
},
99101
});
100102

101-
export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
103+
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasStagingAreaSlice.actions;
102104

103-
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
104-
slice,
105+
export const canvasSessionSliceConfig: SliceConfig<typeof canvasStagingAreaSlice> = {
106+
slice: canvasStagingAreaSlice,
105107
schema: zCanvasStagingAreaState,
106108
getInitialState,
107109
persistConfig: {
@@ -119,7 +121,7 @@ export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
119121

120122
state = {
121123
_version: 2,
122-
sessions: [session],
124+
sessions: { [session.canvasId]: session },
123125
};
124126
}
125127

@@ -128,28 +130,32 @@ export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
128130
},
129131
};
130132

131-
const findSessionByCanvasId = (sessions: CanvasSessionState[], canvasId: string) => {
132-
const session = sessions.find((s) => s.canvasId === canvasId);
133+
const findSessionByCanvasId = (sessions: Record<string, CanvasSessionState>, canvasId: string) => {
134+
const session = sessions[canvasId];
133135
assert(session, 'Session must exist for a canvas once the canvas has been created');
134136
return session;
135137
};
136138
export const selectCanvasSessionByCanvasId = (state: RootState, canvasId: string) =>
137139
findSessionByCanvasId(state.canvasSession.sessions, canvasId);
138-
const selectSelectedCanvasSession = (state: RootState) => {
140+
const selectActiveCanvasSession = (state: RootState) => {
139141
const canvasId = selectActiveCanvasId(state);
140142
return findSessionByCanvasId(state.canvasSession.sessions, canvasId);
141143
};
144+
export const selectCanvasSessionBySessionId = (state: RootState, sessionId: string) => {
145+
const session = Object.values(state.canvasSession.sessions).find((s) => s.canvasSessionId === sessionId);
146+
assert(session, 'Session does not exist');
147+
return session;
148+
};
142149
export const selectCanvasSessionId = (state: RootState, canvasId: string) => {
143150
const session = selectCanvasSessionByCanvasId(state, canvasId);
144151
return session.canvasSessionId;
145152
};
146-
export const selectSelectedCanvasSessionId = (state: RootState) => {
147-
const session = selectSelectedCanvasSession(state);
153+
export const selectActiveCanvasSessionId = (state: RootState) => {
154+
const session = selectActiveCanvasSession(state);
148155
return session.canvasSessionId;
149156
};
150157
const selectCanvasSessionDiscardedItemsBySessionId = (state: RootState, sessionId: string) => {
151-
const session = state.canvasSession.sessions.find((s) => s.canvasSessionId === sessionId);
152-
assert(session, 'Session does not exist');
158+
const session = selectCanvasSessionBySessionId(state, sessionId);
153159
return session.canvasDiscardedQueueItems;
154160
};
155161
export const buildSelectCanvasQueueItemsBySessionId = (sessionId: string) =>

0 commit comments

Comments
 (0)