Skip to content

Commit 95675c0

Browse files
psychedelicioushipsterusername
authored andcommitted
feat(ui): use zod to define canvas state
By modeling canvas state as a zod schema vs a Typescript type, we get a runtime validator that can be used for metadata recall.
1 parent 4dc1945 commit 95675c0

File tree

7 files changed

+147
-110
lines changed

7 files changed

+147
-110
lines changed

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import { getScaledBoundingBoxDimensions } from 'features/controlLayers/util/getS
2424
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
2525
import { zModelIdentifierField } from 'features/nodes/types/common';
2626
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
27-
import { ASPECT_RATIO_MAP, initialAspectRatioState } from 'features/parameters/components/Bbox/constants';
28-
import type { AspectRatioID } from 'features/parameters/components/Bbox/types';
27+
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
2928
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
3029
import type { IRect } from 'konva/lib/types';
3130
import { merge, omit } from 'lodash-es';
@@ -35,6 +34,7 @@ import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterM
3534
import { assert } from 'tsafe';
3635

3736
import type {
37+
AspectRatioID,
3838
BoundingBoxScaleMethod,
3939
CanvasControlLayerState,
4040
CanvasEntityIdentifier,
@@ -92,7 +92,11 @@ const getInitialState = (): CanvasState => {
9292
bbox: {
9393
rect: { x: 0, y: 0, width: 512, height: 512 },
9494
optimalDimension: 512,
95-
aspectRatio: deepClone(initialAspectRatioState),
95+
aspectRatio: {
96+
id: '1:1',
97+
value: 1,
98+
isLocked: false,
99+
},
96100
scaleMethod: 'auto',
97101
scaledSize: {
98102
width: 512,
@@ -739,7 +743,7 @@ export const canvasSlice = createSlice({
739743
state.bbox.rect.width = width;
740744
state.bbox.rect.height = height;
741745
} else {
742-
state.bbox.aspectRatio = deepClone(initialAspectRatioState);
746+
state.bbox.aspectRatio = deepClone(initialState.bbox.aspectRatio);
743747
state.bbox.rect.width = state.bbox.optimalDimension;
744748
state.bbox.rect.height = state.bbox.optimalDimension;
745749
}

invokeai/frontend/web/src/features/controlLayers/store/types.test.ts

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,66 @@ import type { Invocation } from 'services/api/types';
1919
import type { Equals } from 'tsafe';
2020
import { assert } from 'tsafe';
2121
import { describe, test } from 'vitest';
22+
import type { z } from 'zod';
2223

23-
import type { CLIPVisionModelV2, ControlModeV2, IPMethodV2 } from './types';
24+
import type {
25+
CanvasEntityIdentifier,
26+
CLIPVisionModelV2,
27+
ControlModeV2,
28+
IPMethodV2,
29+
zCanvasEntityIdentifer,
30+
} from './types';
2431

2532
describe('Control Adapter Types', () => {
26-
test('ProcessorType', () => {
33+
test('FilterType', () => {
34+
// FilterType is a union of all filter types. FilterConfig is inferred from a zod union. zod does not support
35+
// extracting a specific field from a union, so FilterType is defined separately. To ensure that FilterType is
36+
// consistent with FilterConfig['type'], we compare the two types.
2737
assert<Equals<FilterConfig['type'], FilterType>>();
2838
});
2939
test('IP Adapter Method', () => {
40+
// This ensures the manually defined IPMethodV2 type is consistent with the type we get from the API.
3041
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();
3142
});
3243
test('CLIP Vision Model', () => {
44+
// This ensures the manually defined CLIPVisionModelV2 type is consistent with the type we get from the API.
3345
assert<Equals<NonNullable<Invocation<'ip_adapter'>['clip_vision_model']>, CLIPVisionModelV2>>();
3446
});
3547
test('Control Mode', () => {
48+
// This ensures the manually defined ControlModeV2 type is consistent with the type we get from the API.
3649
assert<Equals<NonNullable<Invocation<'controlnet'>['control_mode']>, ControlModeV2>>();
3750
});
3851
test('DepthAnything Model Size', () => {
52+
// This ensures the manually defined DepthAnythingModelSize type is consistent with the type we get from the API.
3953
assert<Equals<NonNullable<Invocation<'depth_anything_depth_estimation'>['model_size']>, DepthAnythingModelSize>>();
4054
});
4155
test('Processor Configs', () => {
56+
// Types derived from OpenAPI
57+
type _CannyEdgeDetectionFilterConfig = Required<
58+
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
59+
>;
60+
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
61+
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
62+
type _DepthAnythingFilterConfig = Required<
63+
Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>
64+
>;
65+
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
66+
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
67+
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
68+
type _MediaPipeFaceDetectionFilterConfig = Required<
69+
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
70+
>;
71+
type _MLSDDetectionFilterConfig = Required<
72+
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
73+
>;
74+
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
75+
type _DWOpenposeDetectionFilterConfig = Required<
76+
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
77+
>;
78+
type _PiDiNetEdgeDetectionFilterConfig = Required<
79+
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
80+
>;
81+
4282
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
4383
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
4484
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
@@ -54,28 +94,9 @@ describe('Control Adapter Types', () => {
5494
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
5595
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();
5696
});
97+
test('CanvasEntityIdentifier', () => {
98+
// The generic type `CanvasEntityIdentifier` is defined manually, but it must be equal to the inferred type from
99+
// the zod schema.
100+
assert<Equals<CanvasEntityIdentifier, z.infer<typeof zCanvasEntityIdentifer>>>();
101+
});
57102
});
58-
59-
// Types derived from OpenAPI
60-
type _CannyEdgeDetectionFilterConfig = Required<
61-
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
62-
>;
63-
type _ColorMapFilterConfig = Required<Pick<Invocation<'color_map'>, 'type' | 'tile_size'>>;
64-
type _ContentShuffleFilterConfig = Required<Pick<Invocation<'content_shuffle'>, 'type' | 'scale_factor'>>;
65-
type _DepthAnythingFilterConfig = Required<Pick<Invocation<'depth_anything_depth_estimation'>, 'type' | 'model_size'>>;
66-
type _HEDEdgeDetectionFilterConfig = Required<Pick<Invocation<'hed_edge_detection'>, 'type' | 'scribble'>>;
67-
type _LineartAnimeEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_anime_edge_detection'>, 'type'>>;
68-
type _LineartEdgeDetectionFilterConfig = Required<Pick<Invocation<'lineart_edge_detection'>, 'type' | 'coarse'>>;
69-
type _MediaPipeFaceDetectionFilterConfig = Required<
70-
Pick<Invocation<'mediapipe_face_detection'>, 'type' | 'max_faces' | 'min_confidence'>
71-
>;
72-
type _MLSDDetectionFilterConfig = Required<
73-
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
74-
>;
75-
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
76-
type _DWOpenposeDetectionFilterConfig = Required<
77-
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
78-
>;
79-
type _PiDiNetEdgeDetectionFilterConfig = Required<
80-
Pick<Invocation<'pidi_edge_detection'>, 'type' | 'quantize_edges' | 'scribble'>
81-
>;

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 83 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import type { SerializableObject } from 'common/types';
22
import { zModelIdentifierField } from 'features/nodes/types/common';
3-
import type { AspectRatioState } from 'features/parameters/components/Bbox/types';
4-
import type { ParameterHeight, ParameterLoRAModel, ParameterWidth } from 'features/parameters/types/parameterSchemas';
5-
import { zParameterNegativePrompt, zParameterPositivePrompt } from 'features/parameters/types/parameterSchemas';
3+
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
4+
import {
5+
zParameterImageDimension,
6+
zParameterNegativePrompt,
7+
zParameterPositivePrompt,
8+
} from 'features/parameters/types/parameterSchemas';
69
import type { ImageDTO } from 'services/api/types';
710
import { z } from 'zod';
811

@@ -217,20 +220,36 @@ export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
217220
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
218221
zBoundingBoxScaleMethod.safeParse(v).success;
219222

220-
export type CanvasEntityState =
221-
| CanvasRasterLayerState
222-
| CanvasControlLayerState
223-
| CanvasRegionalGuidanceState
224-
| CanvasInpaintMaskState
225-
| CanvasReferenceImageState;
223+
const zCanvasEntityState = z.discriminatedUnion('type', [
224+
zCanvasRasterLayerState,
225+
zCanvasControlLayerState,
226+
zCanvasRegionalGuidanceState,
227+
zCanvasInpaintMaskState,
228+
zCanvasReferenceImageState,
229+
]);
230+
export type CanvasEntityState = z.infer<typeof zCanvasEntityState>;
226231

227-
export type CanvasRenderableEntityState =
228-
| CanvasRasterLayerState
229-
| CanvasControlLayerState
230-
| CanvasRegionalGuidanceState
231-
| CanvasInpaintMaskState;
232+
const zCanvasRenderableEntityState = z.discriminatedUnion('type', [
233+
zCanvasRasterLayerState,
234+
zCanvasControlLayerState,
235+
zCanvasRegionalGuidanceState,
236+
zCanvasInpaintMaskState,
237+
]);
238+
export type CanvasRenderableEntityState = z.infer<typeof zCanvasRenderableEntityState>;
239+
240+
const zCanvasEntityType = z.union([
241+
zCanvasRasterLayerState.shape.type,
242+
zCanvasControlLayerState.shape.type,
243+
zCanvasRegionalGuidanceState.shape.type,
244+
zCanvasInpaintMaskState.shape.type,
245+
zCanvasReferenceImageState.shape.type,
246+
]);
247+
export type CanvasEntityType = z.infer<typeof zCanvasEntityType>;
232248

233-
export type CanvasEntityType = CanvasEntityState['type'];
249+
export const zCanvasEntityIdentifer = z.object({
250+
id: zId,
251+
type: zCanvasEntityType,
252+
});
234253
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
235254

236255
export type LoRA = {
@@ -246,45 +265,55 @@ export type StagingAreaImage = {
246265
offsetY: number;
247266
};
248267

249-
export type CanvasState = {
250-
_version: 3;
251-
selectedEntityIdentifier: CanvasEntityIdentifier | null;
252-
bookmarkedEntityIdentifier: CanvasEntityIdentifier | null;
253-
inpaintMasks: {
254-
isHidden: boolean;
255-
entities: CanvasInpaintMaskState[];
256-
};
257-
rasterLayers: {
258-
isHidden: boolean;
259-
entities: CanvasRasterLayerState[];
260-
};
261-
controlLayers: {
262-
isHidden: boolean;
263-
entities: CanvasControlLayerState[];
264-
};
265-
regionalGuidance: {
266-
isHidden: boolean;
267-
entities: CanvasRegionalGuidanceState[];
268-
};
269-
referenceImages: {
270-
entities: CanvasReferenceImageState[];
271-
};
272-
bbox: {
273-
rect: {
274-
x: number;
275-
y: number;
276-
width: ParameterWidth;
277-
height: ParameterHeight;
278-
};
279-
aspectRatio: AspectRatioState;
280-
scaledSize: {
281-
width: ParameterWidth;
282-
height: ParameterHeight;
283-
};
284-
scaleMethod: BoundingBoxScaleMethod;
285-
optimalDimension: number;
286-
};
287-
};
268+
const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
269+
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
270+
export const isAspectRatioID = (v: string): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
271+
272+
const zCanvasState = z.object({
273+
_version: z.literal(3),
274+
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
275+
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
276+
inpaintMasks: z.object({
277+
isHidden: z.boolean(),
278+
entities: z.array(zCanvasInpaintMaskState),
279+
}),
280+
rasterLayers: z.object({
281+
isHidden: z.boolean(),
282+
entities: z.array(zCanvasRasterLayerState),
283+
}),
284+
controlLayers: z.object({
285+
isHidden: z.boolean(),
286+
entities: z.array(zCanvasControlLayerState),
287+
}),
288+
regionalGuidance: z.object({
289+
isHidden: z.boolean(),
290+
entities: z.array(zCanvasRegionalGuidanceState),
291+
}),
292+
referenceImages: z.object({
293+
entities: z.array(zCanvasReferenceImageState),
294+
}),
295+
bbox: z.object({
296+
rect: z.object({
297+
x: z.number().int(),
298+
y: z.number().int(),
299+
width: zParameterImageDimension,
300+
height: zParameterImageDimension,
301+
}),
302+
aspectRatio: z.object({
303+
id: zAspectRatioID,
304+
value: z.number().gt(0),
305+
isLocked: z.boolean(),
306+
}),
307+
scaledSize: z.object({
308+
width: zParameterImageDimension,
309+
height: zParameterImageDimension,
310+
}),
311+
scaleMethod: zBoundingBoxScaleMethod,
312+
optimalDimension: z.number().int().positive(),
313+
}),
314+
});
315+
316+
export type CanvasState = z.infer<typeof zCanvasState>;
288317

289318
export type StageAttrs = {
290319
x: Coordinate['x'];

invokeai/frontend/web/src/features/parameters/components/Bbox/BboxAspectRatioSelect.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import type { SingleValue } from 'chakra-react-select';
55
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
66
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
77
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
8+
import { isAspectRatioID } from 'features/controlLayers/store/types';
89
import { ASPECT_RATIO_OPTIONS } from 'features/parameters/components/Bbox/constants';
9-
import { isAspectRatioID } from 'features/parameters/components/Bbox/types';
1010
import { memo, useCallback, useMemo } from 'react';
1111
import { useTranslation } from 'react-i18next';
1212

invokeai/frontend/web/src/features/parameters/components/Bbox/constants.ts

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import type { ComboboxOption } from '@invoke-ai/ui-library';
2-
3-
import type { AspectRatioID, AspectRatioState } from './types';
2+
import type { AspectRatioID } from 'features/controlLayers/store/types';
43

54
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
65
{ label: 'Free' as const, value: 'Free' },
@@ -22,9 +21,3 @@ export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: n
2221
'2:3': { ratio: 2 / 3, inverseID: '3:2' },
2322
'9:16': { ratio: 9 / 16, inverseID: '16:9' },
2423
};
25-
26-
export const initialAspectRatioState: AspectRatioState = {
27-
id: '1:1',
28-
value: 1,
29-
isLocked: false,
30-
};

invokeai/frontend/web/src/features/parameters/components/Bbox/types.ts

Lines changed: 0 additions & 11 deletions
This file was deleted.

invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,19 @@ export const isParameterSeed = (val: unknown): val is ParameterSeed => zParamete
8282
// #endregion
8383

8484
// #region Width
85-
const zParameterWidth = z
85+
export const zParameterImageDimension = z
8686
.number()
8787
.min(64)
8888
.transform((val) => roundToMultiple(val, 8));
89-
export type ParameterWidth = z.infer<typeof zParameterWidth>;
90-
export const isParameterWidth = (val: unknown): val is ParameterWidth => zParameterWidth.safeParse(val).success;
89+
export type ParameterWidth = z.infer<typeof zParameterImageDimension>;
90+
export const isParameterWidth = (val: unknown): val is ParameterWidth =>
91+
zParameterImageDimension.safeParse(val).success;
9192
// #endregion
9293

9394
// #region Height
94-
const zParameterHeight = zParameterWidth;
95-
export type ParameterHeight = z.infer<typeof zParameterHeight>;
96-
export const isParameterHeight = (val: unknown): val is ParameterHeight => zParameterHeight.safeParse(val).success;
95+
export type ParameterHeight = z.infer<typeof zParameterImageDimension>;
96+
export const isParameterHeight = (val: unknown): val is ParameterHeight =>
97+
zParameterImageDimension.safeParse(val).success;
9798
// #endregion
9899

99100
// #region Model

0 commit comments

Comments
 (0)