Skip to content

Commit 65d2011

Browse files
author
Attila Cseh
committed
multi-canvas Redux refactor
1 parent 36e400d commit 65d2011

File tree

208 files changed

+2049
-1221
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

208 files changed

+2049
-1221
lines changed

invokeai/frontend/web/.storybook/ReduxInit.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import type { PropsWithChildren } from 'react';
33
import { memo, useEffect } from 'react';
44

55
import { useAppDispatch } from '../src/app/store/storeHooks';
6-
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
6+
import { modelChanged } from 'features/controlLayers/store/actions';
77
/**
88
* Initializes some state for storybook. Must be in a different component
99
* so that it is run inside the redux context.
@@ -13,7 +13,9 @@ export const ReduxInit = memo(({ children }: PropsWithChildren) => {
1313
useGlobalModifiersInit();
1414
useEffect(() => {
1515
dispatch(
16-
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
16+
modelChanged({
17+
model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' },
18+
})
1719
);
1820
}, [dispatch]);
1921

invokeai/frontend/web/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"@invoke-ai/ui-library": "github:invoke-ai/ui-library#v0.0.48",
4949
"@nanostores/react": "^1.0.0",
5050
"@observ33r/object-equals": "^1.1.5",
51-
"@reduxjs/toolkit": "2.8.2",
51+
"@reduxjs/toolkit": "2.9.0",
5252
"@roarr/browser-log-writer": "^1.3.0",
5353
"@xyflow/react": "^12.8.2",
5454
"ag-psd": "^28.2.2",

invokeai/frontend/web/pnpm-lock.yaml

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import type { Middleware, UnknownAction } from '@reduxjs/toolkit';
2+
import { injectTabActionContext } from 'app/store/util';
3+
import { isCanvasInstanceAction } from 'features/controlLayers/store/canvasSlice';
4+
import { selectActiveCanvasId, selectActiveTab } from 'features/controlLayers/store/selectors';
5+
import { isTabInstanceParamsAction } from 'features/controlLayers/store/tabSlice';
6+
7+
export const actionContextMiddleware: Middleware = (store) => (next) => (action) => {
8+
const currentAction = action as UnknownAction;
9+
10+
if (isTabActionContextRequired(currentAction)) {
11+
const state = store.getState();
12+
const tab = selectActiveTab(state);
13+
const canvasId = tab === 'canvas' ? selectActiveCanvasId(state) : undefined;
14+
15+
injectTabActionContext(currentAction, tab, canvasId);
16+
}
17+
18+
return next(action);
19+
};
20+
21+
const isTabActionContextRequired = (action: UnknownAction) => {
22+
return isTabInstanceParamsAction(action) || isCanvasInstanceAction(action);
23+
};

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appStarted.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { createAction } from '@reduxjs/toolkit';
22
import type { AppStartListening } from 'app/store/store';
33
import { noop } from 'es-toolkit';
4-
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
4+
import { selectInfillMethod, setInfillMethod } from 'features/controlLayers/store/paramsSlice';
55
import { selectLastSelectedItem } from 'features/gallery/store/gallerySelectors';
66
import { imageSelected } from 'features/gallery/store/gallerySlice';
77
import { appInfoApi } from 'services/api/endpoints/appInfo';
@@ -36,7 +36,7 @@ export const addAppStartedListener = (startAppListening: AppStartListening) => {
3636
dispatch(appInfoApi.endpoints.getPatchmatchStatus.initiate())
3737
.unwrap()
3838
.then((isPatchmatchAvailable) => {
39-
const infillMethod = getState().params.infillMethod;
39+
const infillMethod = selectInfillMethod(getState());
4040

4141
if (!isPatchmatchAvailable && infillMethod === 'patchmatch') {
4242
dispatch(setInfillMethod('lama'));

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { AppStartListening } from 'app/store/store';
22
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
3-
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
3+
import { selectCanvases } from 'features/controlLayers/store/selectors';
44
import { getImageUsage } from 'features/deleteImageModal/store/state';
55
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
66
import { selectNodesSlice } from 'features/nodes/store/selectors';
@@ -19,12 +19,12 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
1919

2020
const state = getState();
2121
const nodes = selectNodesSlice(state);
22-
const canvas = selectCanvasSlice(state);
22+
const canvases = selectCanvases(state);
2323
const upscale = selectUpscaleSlice(state);
2424
const refImages = selectRefImagesSlice(state);
2525

2626
deleted_images.forEach((image_name) => {
27-
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
27+
const imageUsage = getImageUsage(nodes, canvases, upscale, refImages, image_name);
2828

2929
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
3030
dispatch(nodeEditorReset());

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import { logger } from 'app/logging/logger';
22
import type { AppStartListening } from 'app/store/store';
3+
import { modelChanged } from 'features/controlLayers/store/actions';
34
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
4-
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
5-
import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice';
6-
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
5+
import { selectActiveCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
6+
import { loraIsEnabledChanged, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
7+
import { selectActiveTabParams, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
78
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
89
import {
10+
selectActiveCanvas,
911
selectAllEntitiesOfType,
1012
selectBboxModelBase,
11-
selectCanvasSlice,
1213
} from 'features/controlLayers/store/selectors';
1314
import { getEntityIdentifier } from 'features/controlLayers/store/types';
1415
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models';
@@ -25,7 +26,8 @@ const log = logger('models');
2526
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
2627
startAppListening({
2728
actionCreator: modelSelected,
28-
effect: (action, { getState, dispatch }) => {
29+
effect: (action, api) => {
30+
const { getState, dispatch } = api;
2931
const state = getState();
3032
const result = zParameterModel.safeParse(action.payload);
3133

@@ -36,22 +38,23 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
3638

3739
const newModel = result.data;
3840
const newBase = newModel.base;
39-
const didBaseModelChange = state.params.model?.base !== newBase;
41+
const params = selectActiveTabParams(state);
42+
const didBaseModelChange = params.model?.base !== newBase;
4043

4144
if (didBaseModelChange) {
4245
// we may need to reset some incompatible submodels
4346
let modelsUpdatedDisabledOrCleared = 0;
4447

4548
// handle incompatible loras
46-
state.loras.loras.forEach((lora) => {
49+
selectAddedLoRAs(state).forEach((lora) => {
4750
if (lora.model.base !== newBase) {
4851
dispatch(loraIsEnabledChanged({ id: lora.id, isEnabled: false }));
4952
modelsUpdatedDisabledOrCleared += 1;
5053
}
5154
});
5255

5356
// handle incompatible vae
54-
const { vae } = state.params;
57+
const { vae } = params;
5558
if (vae && vae.base !== newBase) {
5659
dispatch(vaeSelected(null));
5760
modelsUpdatedDisabledOrCleared += 1;
@@ -105,7 +108,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
105108
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
106109

107110
// All regional guidance entities are updated to use the same new model.
108-
const canvasState = selectCanvasSlice(state);
111+
const canvasState = selectActiveCanvas(state);
109112
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
110113
for (const entity of canvasRegionalGuidanceEntities) {
111114
for (const refImage of entity.referenceImages) {
@@ -139,14 +142,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
139142
}
140143
}
141144

142-
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
145+
dispatch(modelChanged({ model: newModel, previousModel: params.model }));
143146

144147
const modelBase = selectBboxModelBase(state);
145148

146-
if (modelBase !== state.params.model?.base) {
149+
if (modelBase !== params.model?.base) {
147150
// Sync generate tab settings whenever the model base changes
148151
dispatch(syncedToOptimalDimension());
149-
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
152+
const isStaging = selectActiveCanvasIsStaging(state);
150153
if (!isStaging) {
151154
// Canvas tab only syncs if not staging
152155
dispatch(bboxSyncedToOptimalDimension());

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import { logger } from 'app/logging/logger';
22
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
3+
import { modelChanged } from 'features/controlLayers/store/actions';
34
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
4-
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
5+
import { loraDeleted, selectAddedLoRAs } from 'features/controlLayers/store/lorasSlice';
56
import {
67
clipEmbedModelSelected,
78
fluxVAESelected,
8-
modelChanged,
99
refinerModelChanged,
10+
selectActiveTabParams,
1011
t5EncoderModelSelected,
1112
vaeSelected,
1213
} from 'features/controlLayers/store/paramsSlice';
1314
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
14-
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
15+
import { selectActiveCanvas } from 'features/controlLayers/store/selectors';
1516
import {
1617
getEntityIdentifier,
1718
isFLUXReduxConfig,
@@ -99,7 +100,7 @@ type ModelHandler = (
99100
) => undefined;
100101

101102
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
102-
const selectedMainModel = state.params.model;
103+
const selectedMainModel = selectActiveTabParams(state).model;
103104
const allMainModels = models.filter(isNonRefinerMainModelConfig).sort((a) => (a.base === 'sdxl' ? -1 : 1));
104105

105106
const firstModel = allMainModels[0];
@@ -127,7 +128,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
127128
};
128129

129130
const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
130-
const selectedRefinerModel = state.params.refinerModel;
131+
const selectedRefinerModel = selectActiveTabParams(state).refinerModel;
131132

132133
// `null` is a valid refiner model - no need to do anything.
133134
if (selectedRefinerModel === null) {
@@ -151,7 +152,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, log) => {
151152
};
152153

153154
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
154-
const selectedVAEModel = state.params.vae;
155+
const selectedVAEModel = selectActiveTabParams(state).vae;
155156

156157
// `null` is a valid VAE - it means "use the VAE baked into the currently-selected main model"
157158
if (selectedVAEModel === null) {
@@ -176,7 +177,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
176177

177178
const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
178179
const loraModels = models.filter(isLoRAModelConfig);
179-
state.loras.loras.forEach((lora) => {
180+
selectAddedLoRAs(state).forEach((lora) => {
180181
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
181182
if (isLoRAAvailable) {
182183
return;
@@ -188,7 +189,7 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
188189

189190
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
190191
const caModels = models.filter(isControlLayerModelConfig);
191-
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
192+
selectActiveCanvas(state).controlLayers.entities.forEach((entity) => {
192193
const selectedControlAdapterModel = entity.controlAdapter.model;
193194
// `null` is a valid control adapter model - no need to do anything.
194195
if (!selectedControlAdapterModel) {
@@ -223,7 +224,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
223224
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
224225
});
225226

226-
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
227+
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
227228
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
228229
if (!isRegionalGuidanceIPAdapterConfig(config)) {
229230
return;
@@ -266,7 +267,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
266267
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
267268
});
268269

269-
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
270+
selectActiveCanvas(state).regionalGuidance.entities.forEach((entity) => {
270271
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
271272
if (!isRegionalGuidanceFLUXReduxConfig(config)) {
272273
return;
@@ -384,7 +385,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =
384385
};
385386

386387
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
387-
const selectedT5EncoderModel = state.params.t5EncoderModel;
388+
const selectedT5EncoderModel = selectActiveTabParams(state).t5EncoderModel;
388389
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfigOrSubmodel(m));
389390

390391
// If the currently selected model is available, we don't need to do anything
@@ -412,7 +413,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
412413
};
413414

414415
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
415-
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
416+
const selectedCLIPEmbedModel = selectActiveTabParams(state).clipEmbedModel;
416417
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfigOrSubmodel(m));
417418

418419
// If the currently selected model is available, we don't need to do anything
@@ -440,7 +441,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
440441
};
441442

442443
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
443-
const selectedFLUXVAEModel = state.params.fluxVAE;
444+
const selectedFLUXVAEModel = selectActiveTabParams(state).fluxVAE;
444445
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
445446

446447
// If the currently selected model is available, we don't need to do anything

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import type { AppStartListening } from 'app/store/store';
22
import { isNil } from 'es-toolkit';
33
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
4-
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
4+
import { selectActiveCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
55
import {
66
heightChanged,
7+
selectActiveTabParams,
78
setCfgRescaleMultiplier,
89
setCfgScale,
910
setGuidance,
@@ -13,6 +14,7 @@ import {
1314
vaeSelected,
1415
widthChanged,
1516
} from 'features/controlLayers/store/paramsSlice';
17+
import { selectActiveTab } from 'features/controlLayers/store/selectors';
1618
import { setDefaultSettings } from 'features/parameters/store/actions';
1719
import {
1820
isParameterCFGRescaleMultiplier,
@@ -26,18 +28,18 @@ import {
2628
zParameterVAEModel,
2729
} from 'features/parameters/types/parameterSchemas';
2830
import { toast } from 'features/toast/toast';
29-
import { selectActiveTab } from 'features/ui/store/uiSelectors';
3031
import { t } from 'i18next';
3132
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
3233
import { isNonRefinerMainModelConfig } from 'services/api/types';
3334

3435
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
3536
startAppListening({
3637
actionCreator: setDefaultSettings,
37-
effect: async (action, { dispatch, getState }) => {
38+
effect: async (action, api) => {
39+
const { dispatch, getState } = api;
3840
const state = getState();
3941

40-
const currentModel = state.params.model;
42+
const currentModel = selectActiveTabParams(state).model;
4143

4244
if (!currentModel) {
4345
return;
@@ -115,7 +117,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
115117
}
116118
const setSizeOptions = { updateAspectRatio: true, clamp: true };
117119

118-
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
120+
const isStaging = selectActiveCanvasIsStaging(state);
119121

120122
const activeTab = selectActiveTab(getState());
121123
if (activeTab === 'generate') {

0 commit comments

Comments
 (0)