diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 7b1e24457f3..edb29ead3fb 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -2022,6 +2022,24 @@
"pullBboxIntoLayerError": "Problem Pulling BBox Into Layer",
"pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage",
"pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage",
+ "addAdjustments": "Add Adjustments",
+ "removeAdjustments": "Remove Adjustments",
+ "adjustments": {
+ "simple": "Simple",
+ "curves": "Curves",
+ "heading": "Adjustments",
+ "expand": "Expand adjustments",
+ "collapse": "Collapse adjustments",
+ "brightness": "Brightness",
+ "contrast": "Contrast",
+ "saturation": "Saturation",
+ "temperature": "Temperature",
+ "tint": "Tint",
+ "sharpness": "Sharpness",
+ "finish": "Finish",
+ "reset": "Reset",
+ "master": "Master"
+ },
"regionIsEmpty": "Selected region is empty",
"mergeVisible": "Merge Visible",
"mergeDown": "Merge Down",
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
index ddaefb1073e..13dc30dea20 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx
@@ -4,6 +4,7 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
+import { RasterLayerAdjustmentsPanel } from 'features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel';
import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate';
import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -39,6 +40,7 @@ export const RasterLayer = memo(({ id }: Props) => {
+
{
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const canvasManager = useCanvasManager();
+ const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier));
+ const { t } = useTranslation();
+
+ const hasAdjustments = Boolean(layer?.adjustments);
+ const enabled = Boolean(layer?.adjustments?.enabled);
+ const collapsed = Boolean(layer?.adjustments?.collapsed);
+ const mode = layer?.adjustments?.mode ?? 'simple';
+ const simple = layer?.adjustments?.simple ?? {
+ brightness: 0,
+ contrast: 0,
+ saturation: 0,
+ temperature: 0,
+ tint: 0,
+ sharpness: 0,
+ };
+
+ const onToggleEnabled = useCallback(
+ (v: boolean) => {
+ // Only toggle the enabled state; preserve current mode/collapsed so users can A/B compare
+ dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: { enabled: v } }));
+ },
+ [dispatch, entityIdentifier]
+ );
+
+ const onReset = useCallback(() => {
+ // Reset values to defaults but keep adjustments present; preserve enabled/collapsed/mode
+ dispatch(
+ rasterLayerAdjustmentsSimpleUpdated({
+ entityIdentifier,
+ simple: {
+ brightness: 0,
+ contrast: 0,
+ saturation: 0,
+ temperature: 0,
+ tint: 0,
+ sharpness: 0,
+ },
+ })
+ );
+ const defaultPoints: Array<[number, number]> = [
+ [0, 0],
+ [255, 255],
+ ];
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'master', points: defaultPoints }));
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'r', points: defaultPoints }));
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'g', points: defaultPoints }));
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'b', points: defaultPoints }));
+ }, [dispatch, entityIdentifier]);
+
+ const onToggleCollapsed = useCallback(() => {
+ dispatch(
+ rasterLayerAdjustmentsSet({
+ entityIdentifier,
+ adjustments: { collapsed: !collapsed },
+ })
+ );
+ }, [dispatch, entityIdentifier, collapsed]);
+
+ const onSetMode = useCallback(
+ (nextMode: 'simple' | 'curves') => {
+ if (!layer?.adjustments) {
+ return;
+ }
+ if (nextMode === mode) {
+ return;
+ }
+ dispatch(
+ rasterLayerAdjustmentsSet({
+ entityIdentifier,
+ adjustments: { mode: nextMode },
+ })
+ );
+ },
+ [dispatch, entityIdentifier, layer?.adjustments, mode]
+ );
+
+ // Memoized click handlers to avoid inline arrow functions in JSX
+ const onClickModeSimple = useCallback(() => onSetMode('simple'), [onSetMode]);
+ const onClickModeCurves = useCallback(() => onSetMode('curves'), [onSetMode]);
+
+ const slider = useMemo(
+ () =>
+ ({
+ row: (label: string, value: number, onChange: (v: number) => void, min = -1, max = 1, step = 0.01) => (
+
+
+
+ {label}
+
+
+
+
+
+ ),
+ }) as const,
+ []
+ );
+
+ const onBrightness = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { brightness: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onContrast = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { contrast: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onSaturation = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { saturation: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onTemperature = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { temperature: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onTint = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { tint: v } })),
+ [dispatch, entityIdentifier]
+ );
+ const onSharpness = useCallback(
+ (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { sharpness: v } })),
+ [dispatch, entityIdentifier]
+ );
+
+ const handleToggleEnabled = useCallback(
+ (e: React.ChangeEvent) => onToggleEnabled(e.target.checked),
+ [onToggleEnabled]
+ );
+
+ const onFinish = useCallback(async () => {
+ // Bake current visual into layer pixels, then clear adjustments
+ const adapter = canvasManager.getAdapter(entityIdentifier);
+ if (!adapter || adapter.type !== 'raster_layer_adapter') {
+ return;
+ }
+ const rect = adapter.transformer.getRelativeRect();
+ try {
+ await adapter.renderer.rasterize({ rect, replaceObjects: true });
+ // Clear adjustments after baking
+ dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null }));
+ } catch {
+ // no-op; leave state unchanged on failure
+ }
+ }, [canvasManager, entityIdentifier, dispatch]);
+
+ // Hide the panel entirely until adjustments are added via context menu
+ if (!hasAdjustments) {
+ return null;
+ }
+
+ return (
+ <>
+
+
+ }
+ />
+
+ Adjustments
+
+
+
+
+
+
+
+
+
+
+ {!collapsed && mode === 'simple' && (
+ <>
+ {slider.row(t('controlLayers.adjustments.brightness'), simple.brightness, onBrightness)}
+ {slider.row(t('controlLayers.adjustments.contrast'), simple.contrast, onContrast)}
+ {slider.row(t('controlLayers.adjustments.saturation'), simple.saturation, onSaturation)}
+ {slider.row(t('controlLayers.adjustments.temperature'), simple.temperature, onTemperature)}
+ {slider.row(t('controlLayers.adjustments.tint'), simple.tint, onTint)}
+ {slider.row(t('controlLayers.adjustments.sharpness'), simple.sharpness, onSharpness, 0, 1, 0.01)}
+ >
+ )}
+
+ {!collapsed && mode === 'curves' && }
+ >
+ );
+});
+
+RasterLayerAdjustmentsPanel.displayName = 'RasterLayerAdjustmentsPanel';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx
new file mode 100644
index 00000000000..930afe05873
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx
@@ -0,0 +1,429 @@
+import { Flex, Text } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice';
+import { selectEntity } from 'features/controlLayers/store/selectors';
+import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const DEFAULT_POINTS: Array<[number, number]> = [
+ [0, 0],
+ [255, 255],
+];
+
+type Channel = 'master' | 'r' | 'g' | 'b';
+
+const channelColor: Record = {
+ master: '#888',
+ r: '#e53e3e',
+ g: '#38a169',
+ b: '#3182ce',
+};
+
+const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v);
+
+const sortPoints = (pts: Array<[number, number]>) =>
+ [...pts]
+ .sort((a, b) => a[0] - b[0])
+ .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]);
+
+type CurveGraphProps = {
+ title: string;
+ channel: Channel;
+ points: Array<[number, number]> | undefined;
+ histogram: number[] | null;
+ onChange: (pts: Array<[number, number]>) => void;
+};
+
+const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) {
+ const { title, channel, points, histogram, onChange } = props;
+ const canvasRef = useRef(null);
+ const [localPoints, setLocalPoints] = useState>(sortPoints(points ?? DEFAULT_POINTS));
+ const [dragIndex, setDragIndex] = useState(null);
+
+ useEffect(() => {
+ setLocalPoints(sortPoints(points ?? DEFAULT_POINTS));
+ }, [points]);
+
+ const width = 256;
+ const height = 160;
+ // inner margins to keep a small buffer from edges (left/right/bottom)
+ const MARGIN_LEFT = 8;
+ const MARGIN_RIGHT = 8;
+ const MARGIN_TOP = 8;
+ const MARGIN_BOTTOM = 10;
+ const INNER_WIDTH = width - MARGIN_LEFT - MARGIN_RIGHT;
+ const INNER_HEIGHT = height - MARGIN_TOP - MARGIN_BOTTOM;
+
+ // helpers to map value-space [0..255] to canvas pixels (respecting inner margins)
+ const valueToCanvasX = useCallback(
+ (x: number) => MARGIN_LEFT + (clamp(x, 0, 255) / 255) * INNER_WIDTH,
+ [INNER_WIDTH]
+ );
+ const valueToCanvasY = useCallback(
+ (y: number) => MARGIN_TOP + INNER_HEIGHT - (clamp(y, 0, 255) / 255) * INNER_HEIGHT,
+ [INNER_HEIGHT]
+ );
+ const canvasToValueX = useCallback(
+ (cx: number) => clamp(Math.round(((cx - MARGIN_LEFT) / INNER_WIDTH) * 255), 0, 255),
+ [INNER_WIDTH]
+ );
+ const canvasToValueY = useCallback(
+ (cy: number) => clamp(Math.round(255 - ((cy - MARGIN_TOP) / INNER_HEIGHT) * 255), 0, 255),
+ [INNER_HEIGHT]
+ );
+
+ const draw = useCallback(() => {
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ c.width = width;
+ c.height = height;
+ const ctx = c.getContext('2d');
+ if (!ctx) {
+ return;
+ }
+
+ // background
+ ctx.clearRect(0, 0, width, height);
+ ctx.fillStyle = '#111';
+ ctx.fillRect(0, 0, width, height);
+
+ // grid inside inner rect
+ ctx.strokeStyle = '#2a2a2a';
+ ctx.lineWidth = 1;
+ for (let i = 0; i <= 4; i++) {
+ const y = MARGIN_TOP + (i * INNER_HEIGHT) / 4;
+ ctx.beginPath();
+ ctx.moveTo(MARGIN_LEFT + 0.5, y + 0.5);
+ ctx.lineTo(MARGIN_LEFT + INNER_WIDTH - 0.5, y + 0.5);
+ ctx.stroke();
+ }
+ for (let i = 0; i <= 4; i++) {
+ const x = MARGIN_LEFT + (i * INNER_WIDTH) / 4;
+ ctx.beginPath();
+ ctx.moveTo(x + 0.5, MARGIN_TOP + 0.5);
+ ctx.lineTo(x + 0.5, MARGIN_TOP + INNER_HEIGHT - 0.5);
+ ctx.stroke();
+ }
+
+ // histogram
+ if (histogram) {
+ // logarithmic histogram for readability when values vary widely
+ const logHist = histogram.map((v) => Math.log10((v ?? 0) + 1));
+ const max = Math.max(1e-6, ...logHist);
+ ctx.fillStyle = '#5557';
+ const binW = Math.max(1, INNER_WIDTH / 256);
+ for (let i = 0; i < 256; i++) {
+ const v = logHist[i] ?? 0;
+ const h = Math.round((v / max) * (INNER_HEIGHT - 2));
+ const x = MARGIN_LEFT + Math.floor(i * binW);
+ const y = MARGIN_TOP + INNER_HEIGHT - h;
+ ctx.fillRect(x, y, Math.ceil(binW), h);
+ }
+ }
+
+ // curve
+ const pts = sortPoints(localPoints);
+ ctx.strokeStyle = channelColor[channel];
+ ctx.lineWidth = 2;
+ ctx.beginPath();
+ for (let i = 0; i < pts.length; i++) {
+ const [x, y] = pts[i]!;
+ const cx = valueToCanvasX(x);
+ const cy = valueToCanvasY(y);
+ if (i === 0) {
+ ctx.moveTo(cx, cy);
+ } else {
+ ctx.lineTo(cx, cy);
+ }
+ }
+ ctx.stroke();
+
+ // control points
+ for (let i = 0; i < pts.length; i++) {
+ const [x, y] = pts[i]!;
+ const cx = valueToCanvasX(x);
+ const cy = valueToCanvasY(y);
+ ctx.fillStyle = '#000';
+ ctx.beginPath();
+ ctx.arc(cx, cy, 3.5, 0, Math.PI * 2);
+ ctx.fill();
+ ctx.strokeStyle = channelColor[channel];
+ ctx.lineWidth = 1.5;
+ ctx.stroke();
+ }
+ }, [
+ MARGIN_LEFT,
+ MARGIN_TOP,
+ INNER_HEIGHT,
+ INNER_WIDTH,
+ channel,
+ height,
+ histogram,
+ localPoints,
+ valueToCanvasX,
+ valueToCanvasY,
+ width,
+ ]);
+
+ useEffect(() => {
+ draw();
+ }, [draw]);
+
+ const getNearestPointIndex = useCallback(
+ (mxCanvas: number, myCanvas: number) => {
+ // convert canvas px to value-space [0..255]
+ const xVal = canvasToValueX(mxCanvas);
+ const yVal = canvasToValueY(myCanvas);
+ let best = -1;
+ let bestDist = 9999;
+ for (let i = 0; i < localPoints.length; i++) {
+ const [px, py] = localPoints[i]!;
+ const dx = px - xVal;
+ const dy = py - yVal;
+ const d = dx * dx + dy * dy;
+ if (d < bestDist) {
+ best = i;
+ bestDist = d;
+ }
+ }
+ if (best !== -1 && bestDist <= 20 * 20) {
+ return best;
+ }
+ return -1;
+ },
+ [canvasToValueX, canvasToValueY, localPoints]
+ );
+
+ const handlePointerDown = useCallback(
+ (e: React.PointerEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const rect = c.getBoundingClientRect();
+ const scaleX = c.width / rect.width;
+ const scaleY = c.height / rect.height;
+ const mxCanvas = (e.clientX - rect.left) * scaleX;
+ const myCanvas = (e.clientY - rect.top) * scaleY;
+ const idx = getNearestPointIndex(mxCanvas, myCanvas);
+ if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) {
+ setDragIndex(idx);
+ return;
+ }
+ // add new point
+ const xVal = canvasToValueX(mxCanvas);
+ const yVal = canvasToValueY(myCanvas);
+ const next = sortPoints([...localPoints, [xVal, yVal]]);
+ setLocalPoints(next);
+ setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal));
+ },
+ [canvasToValueX, canvasToValueY, getNearestPointIndex, localPoints]
+ );
+
+ const handlePointerMove = useCallback(
+ (e: React.PointerEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ if (dragIndex === null) {
+ return;
+ }
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const rect = c.getBoundingClientRect();
+ const scaleX = c.width / rect.width;
+ const scaleY = c.height / rect.height;
+ const mxCanvas = (e.clientX - rect.left) * scaleX;
+ const myCanvas = (e.clientY - rect.top) * scaleY;
+ const mxVal = canvasToValueX(mxCanvas);
+ const myVal = canvasToValueY(myCanvas);
+ setLocalPoints((prev) => {
+ const next = [...prev];
+ // clamp endpoints to ends and keep them immutable
+ if (dragIndex === 0) {
+ return prev;
+ }
+ if (dragIndex === prev.length - 1) {
+ return prev;
+ }
+ next[dragIndex] = [mxVal, myVal];
+ return sortPoints(next);
+ });
+ },
+ [canvasToValueX, canvasToValueY, dragIndex]
+ );
+
+ const commit = useCallback(
+ (pts: Array<[number, number]>) => {
+ onChange(sortPoints(pts));
+ },
+ [onChange]
+ );
+
+ const handlePointerUp = useCallback(() => {
+ setDragIndex(null);
+ commit(localPoints);
+ }, [commit, localPoints]);
+
+ const handleDoubleClick = useCallback(
+ (e: React.MouseEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+ const c = canvasRef.current;
+ if (!c) {
+ return;
+ }
+ const rect = c.getBoundingClientRect();
+ const scaleX = c.width / rect.width;
+ const scaleY = c.height / rect.height;
+ const mxCanvas = (e.clientX - rect.left) * scaleX;
+ const myCanvas = (e.clientY - rect.top) * scaleY;
+ const idx = getNearestPointIndex(mxCanvas, myCanvas);
+ if (idx > 0 && idx < localPoints.length - 1) {
+ const next = localPoints.filter((_, i) => i !== idx);
+ setLocalPoints(next);
+ commit(next);
+ }
+ },
+ [commit, getNearestPointIndex, localPoints]
+ );
+
+ const canvasStyle = useMemo(
+ () => ({ width: '100%', height: height, touchAction: 'none', borderRadius: 4, background: '#111' }),
+ [height]
+ );
+
+ return (
+
+
+ {title}
+
+
+
+ );
+});
+
+export const RasterLayerCurvesEditor = memo(() => {
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer');
+ const { t } = useTranslation();
+ const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)) as
+ | CanvasRasterLayerState
+ | undefined;
+
+ const [histMaster, setHistMaster] = useState(null);
+ const [histR, setHistR] = useState(null);
+ const [histG, setHistG] = useState(null);
+ const [histB, setHistB] = useState(null);
+
+ const pointsMaster = layer?.adjustments?.curves.master ?? DEFAULT_POINTS;
+ const pointsR = layer?.adjustments?.curves.r ?? DEFAULT_POINTS;
+ const pointsG = layer?.adjustments?.curves.g ?? DEFAULT_POINTS;
+ const pointsB = layer?.adjustments?.curves.b ?? DEFAULT_POINTS;
+
+ const recalcHistogram = useCallback(() => {
+ try {
+ const rect = adapter.transformer.getRelativeRect();
+ if (rect.width === 0 || rect.height === 0) {
+ setHistMaster(Array(256).fill(0));
+ setHistR(Array(256).fill(0));
+ setHistG(Array(256).fill(0));
+ setHistB(Array(256).fill(0));
+ return;
+ }
+ const imageData = adapter.renderer.getImageData({ rect });
+ const data = imageData.data;
+ const len = data.length / 4;
+ const master = new Array(256).fill(0);
+ const r = new Array(256).fill(0);
+ const g = new Array(256).fill(0);
+ const b = new Array(256).fill(0);
+ // sample every 4th pixel to lighten work
+ for (let i = 0; i < len; i += 4) {
+ const idx = i * 4;
+ const rv = data[idx] as number;
+ const gv = data[idx + 1] as number;
+ const bv = data[idx + 2] as number;
+ const m = Math.round(0.2126 * rv + 0.7152 * gv + 0.0722 * bv);
+ if (m >= 0 && m < 256) {
+ master[m] = (master[m] ?? 0) + 1;
+ }
+ if (rv >= 0 && rv < 256) {
+ r[rv] = (r[rv] ?? 0) + 1;
+ }
+ if (gv >= 0 && gv < 256) {
+ g[gv] = (g[gv] ?? 0) + 1;
+ }
+ if (bv >= 0 && bv < 256) {
+ b[bv] = (b[bv] ?? 0) + 1;
+ }
+ }
+ setHistMaster(master);
+ setHistR(r);
+ setHistG(g);
+ setHistB(b);
+ } catch {
+ // ignore
+ }
+ }, [adapter]);
+
+ useEffect(() => {
+ recalcHistogram();
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [layer?.objects, layer?.adjustments]);
+
+ const onChangePoints = useCallback(
+ (channel: Channel, pts: Array<[number, number]>) => {
+ dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts }));
+ },
+ [dispatch, entityIdentifier]
+ );
+
+ // Memoize per-channel change handlers to avoid inline lambdas in JSX
+ const onChangeMaster = useCallback((pts: Array<[number, number]>) => onChangePoints('master', pts), [onChangePoints]);
+ const onChangeR = useCallback((pts: Array<[number, number]>) => onChangePoints('r', pts), [onChangePoints]);
+ const onChangeG = useCallback((pts: Array<[number, number]>) => onChangePoints('g', pts), [onChangePoints]);
+ const onChangeB = useCallback((pts: Array<[number, number]>) => onChangePoints('b', pts), [onChangePoints]);
+
+ const gridStyles: React.CSSProperties = useMemo(
+ () => ({ display: 'grid', gridTemplateColumns: 'repeat(2, minmax(0, 1fr))', gap: 8 }),
+ []
+ );
+
+ return (
+
+
+
+
+
+
+
+
+ );
+});
+
+RasterLayerCurvesEditor.displayName = 'RasterLayerCurvesEditor';
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
index 65a16a7b4f9..708f7f29cd6 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx
@@ -9,6 +9,7 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component
import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave';
import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject';
import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform';
+import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments';
import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu';
import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu';
import { memo } from 'react';
@@ -21,10 +22,10 @@ export const RasterLayerMenuItems = memo(() => {
-
+
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx
new file mode 100644
index 00000000000..77a939b7bf9
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx
@@ -0,0 +1,38 @@
+import { MenuItem } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { rasterLayerAdjustmentsReset, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice';
+import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiSlidersHorizontalBold } from 'react-icons/pi';
+
+export const RasterLayerMenuItemsAdjustments = memo(() => {
+ const dispatch = useAppDispatch();
+ const entityIdentifier = useEntityIdentifierContext<'raster_layer'>();
+ const { t } = useTranslation();
+ const layer = useAppSelector((s) =>
+ s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id)
+ );
+ const hasAdjustments = Boolean(layer?.adjustments);
+ const onToggleAdjustmentsPresence = useCallback(() => {
+ if (hasAdjustments) {
+ dispatch(rasterLayerAdjustmentsReset({ entityIdentifier }));
+ } else {
+ dispatch(
+ rasterLayerAdjustmentsSet({
+ entityIdentifier,
+ adjustments: { enabled: true, collapsed: false, mode: 'simple' },
+ })
+ );
+ }
+ }, [dispatch, entityIdentifier, hasAdjustments]);
+
+ return (
+ }>
+ {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')}
+
+ );
+});
+
+RasterLayerMenuItemsAdjustments.displayName = 'RasterLayerMenuItemsAdjustments';
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
index 6c55e949377..5995e80663c 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts
@@ -475,7 +475,7 @@ export abstract class CanvasEntityAdapterBase => {
const { rect } = this.manager.stateApi.getBbox();
const rasterizeResult = await withResultAsync(() =>
- this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } })
+ this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } })
);
if (rasterizeResult.isErr()) {
toast({ status: 'error', title: 'Failed to crop to bbox' });
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts
index 7e31b594fac..06620584fc5 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts
@@ -72,7 +72,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase<
this.log.trace({ rect }, 'Getting canvas');
// The opacity may have been changed in response to user selecting a different entity category, so we must restore
// the original opacity before rendering the canvas
- const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] };
+ const attrs: GroupConfig = { opacity: this.state.opacity };
const canvas = this.renderer.getCanvas({ rect, attrs });
return canvas;
};
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts
index ce138c38321..cd8dee6d2f3 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts
@@ -1,4 +1,4 @@
-import { omit } from 'es-toolkit/compat';
+import { omit, throttle } from 'es-toolkit/compat';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
@@ -6,6 +6,7 @@ import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasE
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule';
+import { AdjustmentsCurvesFilter, AdjustmentsSimpleFilter, buildCurveLUT } from 'features/controlLayers/konva/filters';
import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types';
import type { GroupConfig } from 'konva/lib/Group';
import type { JsonObject } from 'type-fest';
@@ -59,13 +60,18 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
if (!prevState || this.state.opacity !== prevState.opacity) {
this.syncOpacity();
}
+
+ // Apply per-layer adjustments as a Konva filter
+ if (!prevState || this.haveAdjustmentsChanged(prevState, this.state)) {
+ this.syncAdjustmentsFilter();
+ }
};
getCanvas = (rect?: Rect): HTMLCanvasElement => {
this.log.trace({ rect }, 'Getting canvas');
// The opacity may have been changed in response to user selecting a different entity category, so we must restore
// the original opacity before rendering the canvas
- const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] };
+ const attrs: GroupConfig = { opacity: this.state.opacity };
const canvas = this.renderer.getCanvas({ rect, attrs });
return canvas;
};
@@ -74,4 +80,79 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase<
const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked'];
return omit(this.state, keysToOmit);
};
+
+ private syncAdjustmentsFilter = () => {
+ const a = this.state.adjustments;
+ const apply = !!a && a.enabled;
+ // The filter operates on the renderer's object group; we can set filters at the group level via renderer
+ const group = this.renderer.konva.objectGroup;
+ if (apply) {
+ const filters = group.filters() ?? [];
+ let nextFilters = filters.filter((f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter);
+ if (a.mode === 'simple') {
+ group.setAttr('adjustmentsSimple', a.simple);
+ group.setAttr('adjustmentsCurves', null);
+ nextFilters = [...nextFilters, AdjustmentsSimpleFilter];
+ } else {
+ // Build LUTs and set curves attr
+ const master = buildCurveLUT(a.curves.master);
+ const r = buildCurveLUT(a.curves.r);
+ const g = buildCurveLUT(a.curves.g);
+ const b = buildCurveLUT(a.curves.b);
+ group.setAttr('adjustmentsCurves', { master, r, g, b });
+ group.setAttr('adjustmentsSimple', null);
+ nextFilters = [...nextFilters, AdjustmentsCurvesFilter];
+ }
+ group.filters(nextFilters);
+ this._throttledCacheRefresh();
+ } else {
+ // Remove our filter if present
+ const filters = (group.filters() ?? []).filter(
+ (f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter
+ );
+ group.filters(filters);
+ group.setAttr('adjustmentsSimple', null);
+ group.setAttr('adjustmentsCurves', null);
+ this._throttledCacheRefresh();
+ }
+ };
+
+ private _throttledCacheRefresh = throttle(() => this.renderer.syncKonvaCache(true), 50);
+
+ private haveAdjustmentsChanged = (prevState: CanvasRasterLayerState, currState: CanvasRasterLayerState): boolean => {
+ const pa = prevState.adjustments;
+ const ca = currState.adjustments;
+ if (pa === ca) {
+ return false;
+ }
+ if (!pa || !ca) {
+ return true;
+ }
+ if (pa.enabled !== ca.enabled) {
+ return true;
+ }
+ if (pa.mode !== ca.mode) {
+ return true;
+ }
+ // simple params
+ const ps = pa.simple;
+ const cs = ca.simple;
+ if (
+ ps.brightness !== cs.brightness ||
+ ps.contrast !== cs.contrast ||
+ ps.saturation !== cs.saturation ||
+ ps.temperature !== cs.temperature ||
+ ps.tint !== cs.tint ||
+ ps.sharpness !== cs.sharpness
+ ) {
+ return true;
+ }
+ // curves reference (UI not implemented yet) - if arrays differ by ref, consider changed
+ const pc = pa.curves;
+ const cc = ca.curves;
+ if (pc !== cc) {
+ return true;
+ }
+ return false;
+ };
}
diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
index 34c5c9ac5de..6a8a704e136 100644
--- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts
@@ -20,3 +20,188 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => {
imageData.data[i * 4 + 3] = Math.min(a, (cMin + cMax) / 2);
}
};
+
+// Utility clamp
+const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v);
+
+type SimpleAdjustParams = {
+ brightness: number; // -1..1 (additive)
+ contrast: number; // -1..1 (scale around 128)
+ saturation: number; // -1..1
+ temperature: number; // -1..1 (blue<->yellow approx)
+ tint: number; // -1..1 (green<->magenta approx)
+ sharpness: number; // -1..1 (light unsharp mask)
+};
+
+/**
+ * Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness).
+ *
+ * Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter.
+ */
+type KonvaFilterThis = { getAttr?: (key: string) => unknown };
+export const AdjustmentsSimpleFilter = function (this: KonvaFilterThis, imageData: ImageData): void {
+ const params = (this?.getAttr?.('adjustmentsSimple') as SimpleAdjustParams | undefined) ?? null;
+ if (!params) {
+ return;
+ }
+
+ const { brightness, contrast, saturation, temperature, tint, sharpness } = params;
+
+ const data = imageData.data;
+ const len = data.length / 4;
+ const width = (imageData as ImageData & { width: number }).width ?? 0;
+ const height = (imageData as ImageData & { height: number }).height ?? 0;
+
+ // Precompute factors
+ const brightnessShift = brightness * 255; // additive shift
+ const contrastFactor = 1 + contrast; // scale around 128
+
+ // Temperature/Tint multipliers
+ const tempK = 0.5;
+ const tintK = 0.5;
+ const rTempMul = 1 + temperature * tempK;
+ const bTempMul = 1 - temperature * tempK;
+ // Tint: green <-> magenta. Positive = magenta (R/B up, G down). Negative = green (G up, R/B down).
+ const t = clamp(tint, -1, 1) * tintK;
+ const mag = Math.abs(t);
+ const rTintMul = t >= 0 ? 1 + mag : 1 - mag;
+ const gTintMul = t >= 0 ? 1 - mag : 1 + mag;
+ const bTintMul = t >= 0 ? 1 + mag : 1 - mag;
+
+ // Saturation matrix (HSL-based approximation via luma coefficients)
+ const lumaR = 0.2126;
+ const lumaG = 0.7152;
+ const lumaB = 0.0722;
+ const S = 1 + saturation; // 0..2
+ const m00 = lumaR * (1 - S) + S;
+ const m01 = lumaG * (1 - S);
+ const m02 = lumaB * (1 - S);
+ const m10 = lumaR * (1 - S);
+ const m11 = lumaG * (1 - S) + S;
+ const m12 = lumaB * (1 - S);
+ const m20 = lumaR * (1 - S);
+ const m21 = lumaG * (1 - S);
+ const m22 = lumaB * (1 - S) + S;
+
+ // First pass: apply per-pixel color adjustments (excluding sharpness)
+ for (let i = 0; i < len; i++) {
+ const idx = i * 4;
+ let r = data[idx + 0] as number;
+ let g = data[idx + 1] as number;
+ let b = data[idx + 2] as number;
+ const a = data[idx + 3] as number;
+
+ // Brightness (additive)
+ r = r + brightnessShift;
+ g = g + brightnessShift;
+ b = b + brightnessShift;
+
+ // Contrast around mid-point 128
+ r = (r - 128) * contrastFactor + 128;
+ g = (g - 128) * contrastFactor + 128;
+ b = (b - 128) * contrastFactor + 128;
+
+ // Temperature (R/B axis) and Tint (G vs Magenta)
+ r = r * rTempMul * rTintMul;
+ g = g * gTintMul;
+ b = b * bTempMul * bTintMul;
+
+ // Saturation via matrix
+ const r2 = r * m00 + g * m01 + b * m02;
+ const g2 = r * m10 + g * m11 + b * m12;
+ const b2 = r * m20 + g * m21 + b * m22;
+
+ data[idx + 0] = clamp(r2, 0, 255);
+ data[idx + 1] = clamp(g2, 0, 255);
+ data[idx + 2] = clamp(b2, 0, 255);
+ data[idx + 3] = a;
+ }
+
+ // Optional sharpen (simple unsharp mask with 3x3 kernel)
+ if (Math.abs(sharpness) > 1e-3 && width > 2 && height > 2) {
+ const src = new Uint8ClampedArray(data); // copy of modified data
+ const a = Math.max(-1, Math.min(1, sharpness)) * 0.5; // amount
+ const center = 1 + 4 * a;
+ const neighbor = -a;
+ for (let y = 1; y < height - 1; y++) {
+ for (let x = 1; x < width - 1; x++) {
+ const idx = (y * width + x) * 4;
+ for (let c = 0; c < 3; c++) {
+ const centerPx = src[idx + c] ?? 0;
+ const leftPx = src[idx - 4 + c] ?? 0;
+ const rightPx = src[idx + 4 + c] ?? 0;
+ const topPx = src[idx - width * 4 + c] ?? 0;
+ const bottomPx = src[idx + width * 4 + c] ?? 0;
+ const v = centerPx * center + leftPx * neighbor + rightPx * neighbor + topPx * neighbor + bottomPx * neighbor;
+ data[idx + c] = clamp(v, 0, 255);
+ }
+ // preserve alpha
+ }
+ }
+ }
+};
+
+// Build a 256-length LUT from 0..255 control points (linear interpolation for v1)
+export const buildCurveLUT = (points: Array<[number, number]>): number[] => {
+ if (!points || points.length === 0) {
+ return Array.from({ length: 256 }, (_, i) => i);
+ }
+ const pts = points
+ .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number])
+ .sort((a, b) => a[0] - b[0]);
+ if ((pts[0]?.[0] ?? 0) !== 0) {
+ pts.unshift([0, pts[0]?.[1] ?? 0]);
+ }
+ const last = pts[pts.length - 1];
+ if ((last?.[0] ?? 255) !== 255) {
+ pts.push([255, last?.[1] ?? 255]);
+ }
+ const lut = new Array(256);
+ let j = 0;
+ for (let x = 0; x <= 255; x++) {
+ while (j < pts.length - 2 && x > (pts[j + 1]?.[0] ?? 255)) {
+ j++;
+ }
+ const p0 = pts[j] ?? [0, 0];
+ const p1 = pts[j + 1] ?? [255, 255];
+ const [x0, y0] = p0;
+ const [x1, y1] = p1;
+ const t = x1 === x0 ? 0 : (x - x0) / (x1 - x0);
+ const y = y0 + (y1 - y0) * t;
+ lut[x] = clamp(Math.round(y), 0, 255);
+ }
+ return lut;
+};
+
+type CurvesAdjustParams = {
+ master: number[];
+ r: number[];
+ g: number[];
+ b: number[];
+};
+
+// Curves filter: apply master curve, then per-channel curves
+export const AdjustmentsCurvesFilter = function (this: KonvaFilterThis, imageData: ImageData): void {
+ const params = (this?.getAttr?.('adjustmentsCurves') as CurvesAdjustParams | undefined) ?? null;
+ if (!params) {
+ return;
+ }
+ const { master, r, g, b } = params;
+ if (!master || !r || !g || !b) {
+ return;
+ }
+ const data = imageData.data;
+ const len = data.length / 4;
+ for (let i = 0; i < len; i++) {
+ const idx = i * 4;
+ const r0 = data[idx + 0] as number;
+ const g0 = data[idx + 1] as number;
+ const b0 = data[idx + 2] as number;
+ const rm = master[r0] ?? r0;
+ const gm = master[g0] ?? g0;
+ const bm = master[b0] ?? b0;
+ data[idx + 0] = clamp(r[rm] ?? rm, 0, 255);
+ data[idx + 1] = clamp(g[gm] ?? gm, 0, 255);
+ data[idx + 2] = clamp(b[bm] ?? bm, 0, 255);
+ }
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
index 61168a0ec5a..bcc8fb9ffc5 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts
@@ -102,6 +102,165 @@ const slice = createSlice({
reducers: {
// undoable canvas state
//#region Raster layers
+ rasterLayerAdjustmentsSet: (
+ state,
+ action: PayloadAction<
+ EntityIdentifierPayload<
+ {
+ adjustments:
+ | NonNullable
+ | { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' }
+ | null;
+ },
+ 'raster_layer'
+ >
+ >
+ ) => {
+ const { entityIdentifier, adjustments } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ if (adjustments === null) {
+ layer.adjustments = null;
+ return;
+ }
+ if (layer.adjustments === null) {
+ layer.adjustments = {
+ version: 1,
+ enabled: true,
+ collapsed: false,
+ mode: 'simple',
+ simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 },
+ curves: {
+ master: [
+ [0, 0],
+ [255, 255],
+ ],
+ r: [
+ [0, 0],
+ [255, 255],
+ ],
+ g: [
+ [0, 0],
+ [255, 255],
+ ],
+ b: [
+ [0, 0],
+ [255, 255],
+ ],
+ },
+ };
+ }
+ if (typeof adjustments === 'object' && adjustments !== null && 'version' in adjustments) {
+ layer.adjustments = merge(layer.adjustments, adjustments as NonNullable);
+ } else {
+ // Shallow toggles only
+ const partial = adjustments as { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' };
+ layer.adjustments = merge(layer.adjustments, partial);
+ }
+ },
+ rasterLayerAdjustmentsReset: (state, action: PayloadAction>) => {
+ const { entityIdentifier } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ layer.adjustments = null;
+ },
+ rasterLayerAdjustmentsSimpleUpdated: (
+ state,
+ action: PayloadAction<
+ EntityIdentifierPayload<
+ {
+ simple: Partial['simple']>>;
+ },
+ 'raster_layer'
+ >
+ >
+ ) => {
+ const { entityIdentifier, simple } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ if (!layer.adjustments) {
+ // initialize baseline
+ layer.adjustments = {
+ version: 1,
+ enabled: true,
+ collapsed: false,
+ mode: 'simple',
+ simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 },
+ curves: {
+ master: [
+ [0, 0],
+ [255, 255],
+ ],
+ r: [
+ [0, 0],
+ [255, 255],
+ ],
+ g: [
+ [0, 0],
+ [255, 255],
+ ],
+ b: [
+ [0, 0],
+ [255, 255],
+ ],
+ },
+ };
+ }
+ layer.adjustments.simple = merge(layer.adjustments.simple, simple);
+ },
+ rasterLayerAdjustmentsCurvesUpdated: (
+ state,
+ action: PayloadAction<
+ EntityIdentifierPayload<
+ {
+ channel: 'master' | 'r' | 'g' | 'b';
+ points: Array<[number, number]>;
+ },
+ 'raster_layer'
+ >
+ >
+ ) => {
+ const { entityIdentifier, channel, points } = action.payload;
+ const layer = selectEntity(state, entityIdentifier);
+ if (!layer) {
+ return;
+ }
+ if (!layer.adjustments) {
+ // initialize baseline
+ layer.adjustments = {
+ version: 1,
+ enabled: true,
+ collapsed: false,
+ mode: 'curves',
+ simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 },
+ curves: {
+ master: [
+ [0, 0],
+ [255, 255],
+ ],
+ r: [
+ [0, 0],
+ [255, 255],
+ ],
+ g: [
+ [0, 0],
+ [255, 255],
+ ],
+ b: [
+ [0, 0],
+ [255, 255],
+ ],
+ },
+ };
+ }
+ layer.adjustments.curves[channel] = points;
+ },
rasterLayerAdded: {
reducer: (
state,
@@ -1621,6 +1780,11 @@ export const {
entityBrushLineAdded,
entityEraserLineAdded,
entityRectAdded,
+ // Raster layer adjustments
+ rasterLayerAdjustmentsSet,
+ rasterLayerAdjustmentsReset,
+ rasterLayerAdjustmentsSimpleUpdated,
+ rasterLayerAdjustmentsCurvesUpdated,
entityDeleted,
entityArrangedForwardOne,
entityArrangedToFront,
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index f771f8c7469..fdba9fa3f03 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -368,6 +368,32 @@ const zCanvasRasterLayerState = zCanvasEntityBase.extend({
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
+ // Optional per-layer color adjustments (simple + curves). When null/undefined, no adjustments are applied.
+ adjustments: z
+ .object({
+ version: z.literal(1),
+ enabled: z.boolean(),
+ collapsed: z.boolean(),
+ mode: z.enum(['simple', 'curves']),
+ simple: z.object({
+ // All simple params normalized to [-1, 1] except sharpness [0, 1]
+ brightness: z.number().gte(-1).lte(1),
+ contrast: z.number().gte(-1).lte(1),
+ saturation: z.number().gte(-1).lte(1),
+ temperature: z.number().gte(-1).lte(1),
+ tint: z.number().gte(-1).lte(1),
+ sharpness: z.number().gte(0).lte(1),
+ }),
+ curves: z.object({
+ // Curves are arrays of [x, y] control points in 0..255 space (no strict monotonic checks here)
+ master: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2),
+ r: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2),
+ g: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2),
+ b: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2),
+ }),
+ })
+ .optional()
+ .nullable(),
});
export type CanvasRasterLayerState = z.infer;
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
index 2d40cf17793..e14cfd546f4 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts
@@ -198,6 +198,7 @@ export const getRasterLayerState = (
objects: [],
opacity: 1,
position: { x: 0, y: 0 },
+ adjustments: null,
};
merge(entityState, overrides);
return entityState;