Skip to content

Commit f909e81

Browse files
feat(ui): better types & runtime guarantees for filter data stored in konva node attrs
1 parent 8c85f16 commit f909e81

File tree

3 files changed

+38
-29
lines changed

3 files changed

+38
-29
lines changed

invokeai/frontend/web/src/features/controlLayers/konva/filters.ts

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55

66
import { clamp } from 'es-toolkit/compat';
7+
import { zCurvesAdjustmentsLUTs, zSimpleAdjustmentsConfig } from 'features/controlLayers/store/types';
78
import type Konva from 'konva';
89

910
/**
@@ -24,25 +25,18 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => {
2425
}
2526
};
2627

27-
type SimpleAdjustParams = {
28-
brightness: number; // -1..1 (additive)
29-
contrast: number; // -1..1 (scale around 128)
30-
saturation: number; // -1..1
31-
temperature: number; // -1..1 (blue<->yellow approx)
32-
tint: number; // -1..1 (green<->magenta approx)
33-
sharpness: number; // -1..1 (light unsharp mask)
34-
};
35-
3628
/**
3729
* Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness).
3830
*
3931
* Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter.
4032
*/
4133
export const AdjustmentsSimpleFilter = function (this: Konva.Node, imageData: ImageData): void {
42-
const params = (this?.getAttr?.('adjustmentsSimple') as SimpleAdjustParams | undefined) ?? null;
43-
if (!params) {
34+
const paramsRaw = this.getAttr('adjustmentsSimple');
35+
const parseResult = zSimpleAdjustmentsConfig.safeParse(paramsRaw);
36+
if (!parseResult.success) {
4437
return;
4538
}
39+
const params = parseResult.data;
4640

4741
const { brightness, contrast, saturation, temperature, tint, sharpness } = params;
4842

@@ -172,19 +166,19 @@ export const buildCurveLUT = (points: Array<[number, number]>): number[] => {
172166
return lut;
173167
};
174168

175-
type CurvesAdjustParams = {
176-
master: number[];
177-
r: number[];
178-
g: number[];
179-
b: number[];
180-
};
181-
182-
// Curves filter: apply master curve, then per-channel curves
169+
/**
170+
* Per-layer curves adjustments filter (master, r, g, b)
171+
*
172+
* Parameters are read from the Konva node attr `adjustmentsCurves` set by the adapter.
173+
*/
183174
export const AdjustmentsCurvesFilter = function (this: Konva.Node, imageData: ImageData): void {
184-
const params = (this?.getAttr?.('adjustmentsCurves') as CurvesAdjustParams | undefined) ?? null;
185-
if (!params) {
175+
const paramsRaw = this.getAttr('adjustmentsCurves');
176+
const parseResult = zCurvesAdjustmentsLUTs.safeParse(paramsRaw);
177+
if (!parseResult.success) {
186178
return;
187179
}
180+
const params = parseResult.data;
181+
188182
const { master, r, g, b } = params;
189183
if (!master || !r || !g || !b) {
190184
return;

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import type {
2828
RasterLayerAdjustments,
2929
RegionalGuidanceRefImageState,
3030
RgbColor,
31-
SimpleConfig,
31+
SimpleAdjustmentsConfig,
3232
} from 'features/controlLayers/store/types';
3333
import {
3434
calculateNewSize,
@@ -148,7 +148,7 @@ const slice = createSlice({
148148
},
149149
rasterLayerAdjustmentsSimpleUpdated: (
150150
state,
151-
action: PayloadAction<EntityIdentifierPayload<{ simple: Partial<SimpleConfig> }, 'raster_layer'>>
151+
action: PayloadAction<EntityIdentifierPayload<{ simple: Partial<SimpleAdjustmentsConfig> }, 'raster_layer'>>
152152
) => {
153153
const { entityIdentifier, simple } = action.payload;
154154
const layer = selectEntity(state, entityIdentifier);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,31 +378,46 @@ const zControlLoRAConfig = z.object({
378378
});
379379
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
380380

381-
const zSimpleConfig = z.object({
382-
// All simple params normalized to [-1, 1] except sharpness [0, 1]
381+
/**
382+
* All simple params normalized to `[-1, 1]` except sharpness `[0, 1]`.
383+
*
384+
* - Brightness: -1 (darken) to 1 (brighten)
385+
* - Contrast: -1 (decrease contrast) to 1 (increase contrast)
386+
* - Saturation: -1 (desaturate) to 1 (saturate)
387+
* - Temperature: -1 (cooler/blue) to 1 (warmer/yellow)
388+
* - Tint: -1 (greener) to 1 (more magenta)
389+
* - Sharpness: 0 (no sharpening) to 1 (maximum sharpening)
390+
*/
391+
export const zSimpleAdjustmentsConfig = z.object({
383392
brightness: z.number().gte(-1).lte(1),
384393
contrast: z.number().gte(-1).lte(1),
385394
saturation: z.number().gte(-1).lte(1),
386395
temperature: z.number().gte(-1).lte(1),
387396
tint: z.number().gte(-1).lte(1),
388397
sharpness: z.number().gte(0).lte(1),
389398
});
390-
export type SimpleConfig = z.infer<typeof zSimpleConfig>;
399+
export type SimpleAdjustmentsConfig = z.infer<typeof zSimpleAdjustmentsConfig>;
391400

392401
const zUint8 = z.number().int().min(0).max(255);
393402
const zChannelPoints = z.array(z.tuple([zUint8, zUint8])).min(2);
394403
const zChannelName = z.enum(['master', 'r', 'g', 'b']);
395-
const zCurvesConfig = z.record(zChannelName, zChannelPoints);
404+
const zCurvesAdjustmentsConfig = z.record(zChannelName, zChannelPoints);
396405
export type ChannelName = z.infer<typeof zChannelName>;
397406
export type ChannelPoints = z.infer<typeof zChannelPoints>;
398407

408+
/**
409+
* The curves adjustments are stored as LUTs in the Konva node attributes. Konva will use these values when applying
410+
* the filter.
411+
*/
412+
export const zCurvesAdjustmentsLUTs = z.record(zChannelName, z.array(zUint8));
413+
399414
const zRasterLayerAdjustments = z.object({
400415
version: z.literal(1),
401416
enabled: z.boolean(),
402417
collapsed: z.boolean(),
403418
mode: z.enum(['simple', 'curves']),
404-
simple: zSimpleConfig,
405-
curves: zCurvesConfig,
419+
simple: zSimpleAdjustmentsConfig,
420+
curves: zCurvesAdjustmentsConfig,
406421
});
407422
export type RasterLayerAdjustments = z.infer<typeof zRasterLayerAdjustments>;
408423

0 commit comments

Comments
 (0)