Skip to content

Commit 871271f

Browse files
psychedeliciousmaryhipp
authored andcommitted
feat(ui): rough out imagen3 support for canvas
1 parent 1494487 commit 871271f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+554
-137
lines changed

invokeai/frontend/web/public/locales/en.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,7 @@
13221322
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
13231323
"unableToCopyDesc_theseSteps": "these steps",
13241324
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
1325+
"image3IncompatibleWithInpaintAndOutpaint": "Imagen3 is does not support Inpainting or Outpainting. Use other models for these tasks.",
13251326
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
13261327
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
13271328
"workflowUnpublished": "Workflow Unpublished"

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { $canvasManager } from 'features/controlLayers/store/ephemeral';
88
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
99
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
1010
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
11+
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
1112
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
1213
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
1314
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
@@ -48,6 +49,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
4849
return await buildFLUXGraph(state, manager);
4950
case 'cogview4':
5051
return await buildCogView4Graph(state, manager);
52+
case 'imagen3':
53+
return await buildImagen3Graph(state, manager);
5154
default:
5255
assert(false, `No graph builders for base ${base}`);
5356
}
@@ -68,12 +71,20 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
6871
return;
6972
}
7073

71-
const { g, noise, posCond } = buildGraphResult.value;
74+
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
7275

7376
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
7477

7578
const prepareBatchResult = withResult(() =>
76-
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'canvas', destination)
79+
prepareLinearUIBatch(
80+
state,
81+
g,
82+
prepend,
83+
seedFieldIdentifier,
84+
positivePromptFieldIdentifier,
85+
'canvas',
86+
destination
87+
)
7788
);
7889

7990
if (prepareBatchResult.isErr()) {
Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { useAppSelector } from 'app/store/storeHooks';
2-
import { selectIsCogView4, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
2+
import { selectIsCogView4, selectIsImagen3, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
33
import type { CanvasEntityType } from 'features/controlLayers/store/types';
44
import { useMemo } from 'react';
55
import type { Equals } from 'tsafe';
@@ -8,23 +8,24 @@ import { assert } from 'tsafe';
88
export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
99
const isSD3 = useAppSelector(selectIsSD3);
1010
const isCogView4 = useAppSelector(selectIsCogView4);
11+
const isImagen3 = useAppSelector(selectIsImagen3);
1112

1213
const isEntityTypeEnabled = useMemo<boolean>(() => {
1314
switch (entityType) {
1415
case 'reference_image':
15-
return !isSD3 && !isCogView4;
16+
return !isSD3 && !isCogView4 && !isImagen3;
1617
case 'regional_guidance':
17-
return !isSD3 && !isCogView4;
18+
return !isSD3 && !isCogView4 && !isImagen3;
1819
case 'control_layer':
19-
return !isSD3 && !isCogView4;
20+
return !isSD3 && !isCogView4 && !isImagen3;
2021
case 'inpaint_mask':
21-
return true;
22+
return !isImagen3;
2223
case 'raster_layer':
23-
return true;
24+
return !isImagen3;
2425
default:
2526
assert<Equals<typeof entityType, never>>(false);
2627
}
27-
}, [entityType, isSD3, isCogView4]);
28+
}, [entityType, isSD3, isCogView4, isImagen3]);
2829

2930
return isEntityTypeEnabled;
3031
};

invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
346346

347347
// If the user is not holding shift, the transform is retaining aspect ratio. It's not possible to snap to the grid
348348
// in this case, because that would change the aspect ratio. So, we only snap to the grid when shift is held.
349-
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getGridSize() : 1;
349+
const gridSize = this.manager.stateApi.$shiftKey.get() ? this.manager.stateApi.getPositionGridSize() : 1;
350350

351351
// We need to snap the anchor to the selected grid size, but the positions provided to this callback are absolute,
352352
// scaled coordinates. They need to be converted to stage coordinates, snapped, then converted back to absolute
@@ -464,7 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
464464
return;
465465
}
466466
const { rect } = this.manager.stateApi.getBbox();
467-
const gridSize = this.manager.stateApi.getGridSize();
467+
const gridSize = this.manager.stateApi.getPositionGridSize();
468468
const width = this.konva.proxyRect.width();
469469
const height = this.konva.proxyRect.height();
470470
const scaleX = rect.width / width;
@@ -498,7 +498,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
498498
return;
499499
}
500500
const { rect } = this.manager.stateApi.getBbox();
501-
const gridSize = this.manager.stateApi.getGridSize();
501+
const gridSize = this.manager.stateApi.getPositionGridSize();
502502
const width = this.konva.proxyRect.width();
503503
const height = this.konva.proxyRect.height();
504504
const scaleX = rect.width / width;
@@ -523,7 +523,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
523523

524524
onDragMove = () => {
525525
// Snap the interaction rect to the grid
526-
const gridSize = this.manager.stateApi.getGridSize();
526+
const gridSize = this.manager.stateApi.getPositionGridSize();
527527
this.konva.proxyRect.x(roundToMultiple(this.konva.proxyRect.x(), gridSize));
528528
this.konva.proxyRect.y(roundToMultiple(this.konva.proxyRect.y(), gridSize));
529529

invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
493493
* Gets the _positional_ grid size for the current canvas. Note that this is not the same as bbox grid size, which is
494494
* based on the currently-selected model.
495495
*/
496-
getGridSize = (): number => {
496+
getPositionGridSize = (): number => {
497497
const snapToGrid = this.getSettings().snapToGrid;
498498
if (!snapToGrid) {
499499
return 1;

invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasBboxToolModule.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'
44
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
55
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
66
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
7+
import { selectModel } from 'features/controlLayers/store/paramsSlice';
78
import { selectBbox } from 'features/controlLayers/store/selectors';
8-
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
9+
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
10+
import type { ModelIdentifierField } from 'features/nodes/types/common';
911
import Konva from 'konva';
1012
import { noop } from 'lodash-es';
1113
import { atom } from 'nanostores';
@@ -178,6 +180,9 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
178180
// Listen for the bbox overlay setting to update the overlay's visibility
179181
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectBboxOverlay, this.render));
180182

183+
// Listen for the model changing - some model types constraint the bbox to a certain size or aspect ratio.
184+
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectModel, this.render));
185+
181186
// Update on busy state changes
182187
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
183188
}
@@ -218,12 +223,25 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
218223

219224
this.syncOverlay();
220225

226+
const model = this.manager.stateApi.runSelector(selectModel);
227+
221228
this.konva.transformer.setAttrs({
222229
listening: tool === 'bbox',
223-
enabledAnchors: tool === 'bbox' ? ALL_ANCHORS : NO_ANCHORS,
230+
enabledAnchors: this.getEnabledAnchors(tool, model),
224231
});
225232
};
226233

234+
getEnabledAnchors = (tool: Tool, model?: ModelIdentifierField | null): string[] => {
235+
if (tool !== 'bbox') {
236+
return NO_ANCHORS;
237+
}
238+
if (model?.base === 'imagen3') {
239+
// The bbox is not resizable in imagen3 mode
240+
return NO_ANCHORS;
241+
}
242+
return ALL_ANCHORS;
243+
};
244+
227245
syncOverlay = () => {
228246
const bboxOverlay = this.manager.stateApi.getSettings().bboxOverlay;
229247

@@ -251,7 +269,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
251269
onDragMove = () => {
252270
// The grid size here is the _position_ grid size, not the _dimension_ grid size - it is not constratined by the
253271
// currently-selected model.
254-
const gridSize = this.manager.stateApi.getGridSize();
272+
const gridSize = this.manager.stateApi.getPositionGridSize();
255273
const bbox = this.manager.stateApi.getBbox();
256274
const bboxRect: Rect = {
257275
...bbox.rect,

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ import type {
6767
IPMethodV2,
6868
T2IAdapterConfig,
6969
} from './types';
70-
import { getEntityIdentifier, isRenderableEntity } from './types';
70+
import { getEntityIdentifier, isImagen3AspectRatioID, isRenderableEntity } from './types';
7171
import {
7272
converters,
7373
getControlLayerState,
@@ -1139,7 +1139,21 @@ export const canvasSlice = createSlice({
11391139
syncScaledSize(state);
11401140
},
11411141
bboxChangedFromCanvas: (state, action: PayloadAction<IRect>) => {
1142-
state.bbox.rect = action.payload;
1142+
const newBboxRect = action.payload;
1143+
const oldBboxRect = state.bbox.rect;
1144+
1145+
state.bbox.rect = newBboxRect;
1146+
1147+
if (newBboxRect.width === oldBboxRect.width && newBboxRect.height === oldBboxRect.height) {
1148+
return;
1149+
}
1150+
1151+
const oldAspectRatio = state.bbox.aspectRatio.value;
1152+
const newAspectRatio = newBboxRect.width / newBboxRect.height;
1153+
1154+
if (oldAspectRatio === newAspectRatio) {
1155+
return;
1156+
}
11431157

11441158
// TODO(psyche): Figure out a way to handle this without resetting the aspect ratio on every change.
11451159
// This action is dispatched when the user resizes or moves the bbox from the canvas. For now, when the user
@@ -1198,6 +1212,26 @@ export const canvasSlice = createSlice({
11981212
state.bbox.aspectRatio.id = id;
11991213
if (id === 'Free') {
12001214
state.bbox.aspectRatio.isLocked = false;
1215+
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
1216+
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
1217+
if (id === '16:9') {
1218+
state.bbox.rect.width = 1408;
1219+
state.bbox.rect.height = 768;
1220+
} else if (id === '4:3') {
1221+
state.bbox.rect.width = 1280;
1222+
state.bbox.rect.height = 896;
1223+
} else if (id === '1:1') {
1224+
state.bbox.rect.width = 1024;
1225+
state.bbox.rect.height = 1024;
1226+
} else if (id === '3:4') {
1227+
state.bbox.rect.width = 896;
1228+
state.bbox.rect.height = 1280;
1229+
} else if (id === '9:16') {
1230+
state.bbox.rect.width = 768;
1231+
state.bbox.rect.height = 1408;
1232+
}
1233+
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
1234+
state.bbox.aspectRatio.isLocked = true;
12011235
} else {
12021236
state.bbox.aspectRatio.isLocked = true;
12031237
state.bbox.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
@@ -1670,6 +1704,13 @@ export const canvasSlice = createSlice({
16701704
const base = model?.base;
16711705
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
16721706
state.bbox.modelBase = base;
1707+
if (base === 'imagen3') {
1708+
state.bbox.aspectRatio.isLocked = true;
1709+
state.bbox.aspectRatio.value = 1;
1710+
state.bbox.aspectRatio.id = '1:1';
1711+
state.bbox.rect.width = 1024;
1712+
state.bbox.rect.height = 1024;
1713+
}
16731714
syncScaledSize(state);
16741715
}
16751716
});
@@ -1802,6 +1843,10 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
18021843
};
18031844

18041845
const syncScaledSize = (state: CanvasState) => {
1846+
if (state.bbox.modelBase === 'imagen3') {
1847+
// Imagen3 has fixed sizes. Scaled bbox is not supported.
1848+
return;
1849+
}
18051850
if (state.bbox.scaleMethod === 'auto') {
18061851
// Sync both aspect ratio and size
18071852
const { width, height } = state.bbox.rect;

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export type ParamsState = {
7979
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
8080
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
8181
controlLora: ParameterControlLoRAModel | null;
82+
imagen3EnhancePrompt: boolean;
8283
};
8384

8485
const initialState: ParamsState = {
@@ -128,6 +129,7 @@ const initialState: ParamsState = {
128129
clipLEmbedModel: null,
129130
clipGEmbedModel: null,
130131
controlLora: null,
132+
imagen3EnhancePrompt: true,
131133
};
132134

133135
export const paramsSlice = createSlice({
@@ -290,6 +292,9 @@ export const paramsSlice = createSlice({
290292
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
291293
state.canvasCoherenceMinDenoise = action.payload;
292294
},
295+
imagen3EnhancePromptChanged: (state, action: PayloadAction<boolean>) => {
296+
state.imagen3EnhancePrompt = action.payload;
297+
},
293298
paramsReset: (state) => resetState(state),
294299
},
295300
extraReducers(builder) {
@@ -357,6 +362,7 @@ export const {
357362
setRefinerStart,
358363
modelChanged,
359364
paramsReset,
365+
imagen3EnhancePromptChanged,
360366
} = paramsSlice.actions;
361367

362368
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -380,6 +386,7 @@ export const selectIsSDXL = createParamsSelector((params) => params.model?.base
380386
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
381387
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
382388
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
389+
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
383390

384391
export const selectModel = createParamsSelector((params) => params.model);
385392
export const selectModelKey = createParamsSelector((params) => params.model?.key);
@@ -414,6 +421,7 @@ export const selectNegativePrompt = createParamsSelector((params) => params.nega
414421
export const selectPositivePrompt2 = createParamsSelector((params) => params.positivePrompt2);
415422
export const selectNegativePrompt2 = createParamsSelector((params) => params.negativePrompt2);
416423
export const selectShouldConcatPrompts = createParamsSelector((params) => params.shouldConcatPrompts);
424+
export const selectImagen3EnhancePrompt = createParamsSelector((params) => params.imagen3EnhancePrompt);
417425
export const selectScheduler = createParamsSelector((params) => params.scheduler);
418426
export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
419427
export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,11 @@ export type StagingAreaImage = {
388388
};
389389

390390
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
391+
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
392+
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
393+
zImagen3AspectRatioID.safeParse(v).success;
391394
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
392-
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
395+
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
393396

394397
const zCanvasState = z.object({
395398
_version: z.literal(3),

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
1616
'sdxl-refiner': 'invokeBlue',
1717
flux: 'gold',
1818
cogview4: 'red',
19+
imagen3: 'pink'
1920
};
2021

2122
const ModelBaseBadge = ({ base }: Props) => {

0 commit comments

Comments
 (0)