diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7b1e24457f3..edb29ead3fb 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2022,6 +2022,24 @@ "pullBboxIntoLayerError": "Problem Pulling BBox Into Layer", "pullBboxIntoReferenceImageOk": "Bbox Pulled Into ReferenceImage", "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", + "addAdjustments": "Add Adjustments", + "removeAdjustments": "Remove Adjustments", + "adjustments": { + "simple": "Simple", + "curves": "Curves", + "heading": "Adjustments", + "expand": "Expand adjustments", + "collapse": "Collapse adjustments", + "brightness": "Brightness", + "contrast": "Contrast", + "saturation": "Saturation", + "temperature": "Temperature", + "tint": "Tint", + "sharpness": "Sharpness", + "finish": "Finish", + "reset": "Reset", + "master": "Master" + }, "regionIsEmpty": "Selected region is empty", "mergeVisible": "Merge Visible", "mergeDown": "Merge Down", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx index ddaefb1073e..13dc30dea20 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayer.tsx @@ -4,6 +4,7 @@ import { CanvasEntityHeader } from 'features/controlLayers/components/common/Can import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions'; import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage'; import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit'; +import { RasterLayerAdjustmentsPanel } from 'features/controlLayers/components/RasterLayer/RasterLayerAdjustmentsPanel'; import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate'; import { RasterLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext'; import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; @@ -39,6 +40,7 @@ export const RasterLayer = memo(({ id }: Props) => { + { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const canvasManager = useCanvasManager(); + const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)); + const { t } = useTranslation(); + + const hasAdjustments = Boolean(layer?.adjustments); + const enabled = Boolean(layer?.adjustments?.enabled); + const collapsed = Boolean(layer?.adjustments?.collapsed); + const mode = layer?.adjustments?.mode ?? 'simple'; + const simple = layer?.adjustments?.simple ?? { + brightness: 0, + contrast: 0, + saturation: 0, + temperature: 0, + tint: 0, + sharpness: 0, + }; + + const onToggleEnabled = useCallback( + (v: boolean) => { + // Only toggle the enabled state; preserve current mode/collapsed so users can A/B compare + dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: { enabled: v } })); + }, + [dispatch, entityIdentifier] + ); + + const onReset = useCallback(() => { + // Reset values to defaults but keep adjustments present; preserve enabled/collapsed/mode + dispatch( + rasterLayerAdjustmentsSimpleUpdated({ + entityIdentifier, + simple: { + brightness: 0, + contrast: 0, + saturation: 0, + temperature: 0, + tint: 0, + sharpness: 0, + }, + }) + ); + const defaultPoints: Array<[number, number]> = [ + [0, 0], + [255, 255], + ]; + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'master', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'r', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'g', points: defaultPoints })); + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel: 'b', points: defaultPoints })); + }, [dispatch, entityIdentifier]); + + const onToggleCollapsed = useCallback(() => { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { collapsed: !collapsed }, + }) + ); + }, [dispatch, entityIdentifier, collapsed]); + + const onSetMode = useCallback( + (nextMode: 'simple' | 'curves') => { + if (!layer?.adjustments) { + return; + } + if (nextMode === mode) { + return; + } + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { mode: nextMode }, + }) + ); + }, + [dispatch, entityIdentifier, layer?.adjustments, mode] + ); + + // Memoized click handlers to avoid inline arrow functions in JSX + const onClickModeSimple = useCallback(() => onSetMode('simple'), [onSetMode]); + const onClickModeCurves = useCallback(() => onSetMode('curves'), [onSetMode]); + + const slider = useMemo( + () => + ({ + row: (label: string, value: number, onChange: (v: number) => void, min = -1, max = 1, step = 0.01) => ( + + + + {label} + + + + + + ), + }) as const, + [] + ); + + const onBrightness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { brightness: v } })), + [dispatch, entityIdentifier] + ); + const onContrast = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { contrast: v } })), + [dispatch, entityIdentifier] + ); + const onSaturation = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { saturation: v } })), + [dispatch, entityIdentifier] + ); + const onTemperature = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { temperature: v } })), + [dispatch, entityIdentifier] + ); + const onTint = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { tint: v } })), + [dispatch, entityIdentifier] + ); + const onSharpness = useCallback( + (v: number) => dispatch(rasterLayerAdjustmentsSimpleUpdated({ entityIdentifier, simple: { sharpness: v } })), + [dispatch, entityIdentifier] + ); + + const handleToggleEnabled = useCallback( + (e: React.ChangeEvent) => onToggleEnabled(e.target.checked), + [onToggleEnabled] + ); + + const onFinish = useCallback(async () => { + // Bake current visual into layer pixels, then clear adjustments + const adapter = canvasManager.getAdapter(entityIdentifier); + if (!adapter || adapter.type !== 'raster_layer_adapter') { + return; + } + const rect = adapter.transformer.getRelativeRect(); + try { + await adapter.renderer.rasterize({ rect, replaceObjects: true }); + // Clear adjustments after baking + dispatch(rasterLayerAdjustmentsSet({ entityIdentifier, adjustments: null })); + } catch { + // no-op; leave state unchanged on failure + } + }, [canvasManager, entityIdentifier, dispatch]); + + // Hide the panel entirely until adjustments are added via context menu + if (!hasAdjustments) { + return null; + } + + return ( + <> + + + } + /> + + Adjustments + + + + + + + + + + + {!collapsed && mode === 'simple' && ( + <> + {slider.row(t('controlLayers.adjustments.brightness'), simple.brightness, onBrightness)} + {slider.row(t('controlLayers.adjustments.contrast'), simple.contrast, onContrast)} + {slider.row(t('controlLayers.adjustments.saturation'), simple.saturation, onSaturation)} + {slider.row(t('controlLayers.adjustments.temperature'), simple.temperature, onTemperature)} + {slider.row(t('controlLayers.adjustments.tint'), simple.tint, onTint)} + {slider.row(t('controlLayers.adjustments.sharpness'), simple.sharpness, onSharpness, 0, 1, 0.01)} + + )} + + {!collapsed && mode === 'curves' && } + + ); +}); + +RasterLayerAdjustmentsPanel.displayName = 'RasterLayerAdjustmentsPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx new file mode 100644 index 00000000000..930afe05873 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerCurvesEditor.tsx @@ -0,0 +1,429 @@ +import { Flex, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsCurvesUpdated } from 'features/controlLayers/store/canvasSlice'; +import { selectEntity } from 'features/controlLayers/store/selectors'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +const DEFAULT_POINTS: Array<[number, number]> = [ + [0, 0], + [255, 255], +]; + +type Channel = 'master' | 'r' | 'g' | 'b'; + +const channelColor: Record = { + master: '#888', + r: '#e53e3e', + g: '#38a169', + b: '#3182ce', +}; + +const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v); + +const sortPoints = (pts: Array<[number, number]>) => + [...pts] + .sort((a, b) => a[0] - b[0]) + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]); + +type CurveGraphProps = { + title: string; + channel: Channel; + points: Array<[number, number]> | undefined; + histogram: number[] | null; + onChange: (pts: Array<[number, number]>) => void; +}; + +const CurveGraph = memo(function CurveGraph(props: CurveGraphProps) { + const { title, channel, points, histogram, onChange } = props; + const canvasRef = useRef(null); + const [localPoints, setLocalPoints] = useState>(sortPoints(points ?? DEFAULT_POINTS)); + const [dragIndex, setDragIndex] = useState(null); + + useEffect(() => { + setLocalPoints(sortPoints(points ?? DEFAULT_POINTS)); + }, [points]); + + const width = 256; + const height = 160; + // inner margins to keep a small buffer from edges (left/right/bottom) + const MARGIN_LEFT = 8; + const MARGIN_RIGHT = 8; + const MARGIN_TOP = 8; + const MARGIN_BOTTOM = 10; + const INNER_WIDTH = width - MARGIN_LEFT - MARGIN_RIGHT; + const INNER_HEIGHT = height - MARGIN_TOP - MARGIN_BOTTOM; + + // helpers to map value-space [0..255] to canvas pixels (respecting inner margins) + const valueToCanvasX = useCallback( + (x: number) => MARGIN_LEFT + (clamp(x, 0, 255) / 255) * INNER_WIDTH, + [INNER_WIDTH] + ); + const valueToCanvasY = useCallback( + (y: number) => MARGIN_TOP + INNER_HEIGHT - (clamp(y, 0, 255) / 255) * INNER_HEIGHT, + [INNER_HEIGHT] + ); + const canvasToValueX = useCallback( + (cx: number) => clamp(Math.round(((cx - MARGIN_LEFT) / INNER_WIDTH) * 255), 0, 255), + [INNER_WIDTH] + ); + const canvasToValueY = useCallback( + (cy: number) => clamp(Math.round(255 - ((cy - MARGIN_TOP) / INNER_HEIGHT) * 255), 0, 255), + [INNER_HEIGHT] + ); + + const draw = useCallback(() => { + const c = canvasRef.current; + if (!c) { + return; + } + c.width = width; + c.height = height; + const ctx = c.getContext('2d'); + if (!ctx) { + return; + } + + // background + ctx.clearRect(0, 0, width, height); + ctx.fillStyle = '#111'; + ctx.fillRect(0, 0, width, height); + + // grid inside inner rect + ctx.strokeStyle = '#2a2a2a'; + ctx.lineWidth = 1; + for (let i = 0; i <= 4; i++) { + const y = MARGIN_TOP + (i * INNER_HEIGHT) / 4; + ctx.beginPath(); + ctx.moveTo(MARGIN_LEFT + 0.5, y + 0.5); + ctx.lineTo(MARGIN_LEFT + INNER_WIDTH - 0.5, y + 0.5); + ctx.stroke(); + } + for (let i = 0; i <= 4; i++) { + const x = MARGIN_LEFT + (i * INNER_WIDTH) / 4; + ctx.beginPath(); + ctx.moveTo(x + 0.5, MARGIN_TOP + 0.5); + ctx.lineTo(x + 0.5, MARGIN_TOP + INNER_HEIGHT - 0.5); + ctx.stroke(); + } + + // histogram + if (histogram) { + // logarithmic histogram for readability when values vary widely + const logHist = histogram.map((v) => Math.log10((v ?? 0) + 1)); + const max = Math.max(1e-6, ...logHist); + ctx.fillStyle = '#5557'; + const binW = Math.max(1, INNER_WIDTH / 256); + for (let i = 0; i < 256; i++) { + const v = logHist[i] ?? 0; + const h = Math.round((v / max) * (INNER_HEIGHT - 2)); + const x = MARGIN_LEFT + Math.floor(i * binW); + const y = MARGIN_TOP + INNER_HEIGHT - h; + ctx.fillRect(x, y, Math.ceil(binW), h); + } + } + + // curve + const pts = sortPoints(localPoints); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 2; + ctx.beginPath(); + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); + if (i === 0) { + ctx.moveTo(cx, cy); + } else { + ctx.lineTo(cx, cy); + } + } + ctx.stroke(); + + // control points + for (let i = 0; i < pts.length; i++) { + const [x, y] = pts[i]!; + const cx = valueToCanvasX(x); + const cy = valueToCanvasY(y); + ctx.fillStyle = '#000'; + ctx.beginPath(); + ctx.arc(cx, cy, 3.5, 0, Math.PI * 2); + ctx.fill(); + ctx.strokeStyle = channelColor[channel]; + ctx.lineWidth = 1.5; + ctx.stroke(); + } + }, [ + MARGIN_LEFT, + MARGIN_TOP, + INNER_HEIGHT, + INNER_WIDTH, + channel, + height, + histogram, + localPoints, + valueToCanvasX, + valueToCanvasY, + width, + ]); + + useEffect(() => { + draw(); + }, [draw]); + + const getNearestPointIndex = useCallback( + (mxCanvas: number, myCanvas: number) => { + // convert canvas px to value-space [0..255] + const xVal = canvasToValueX(mxCanvas); + const yVal = canvasToValueY(myCanvas); + let best = -1; + let bestDist = 9999; + for (let i = 0; i < localPoints.length; i++) { + const [px, py] = localPoints[i]!; + const dx = px - xVal; + const dy = py - yVal; + const d = dx * dx + dy * dy; + if (d < bestDist) { + best = i; + bestDist = d; + } + } + if (best !== -1 && bestDist <= 20 * 20) { + return best; + } + return -1; + }, + [canvasToValueX, canvasToValueY, localPoints] + ); + + const handlePointerDown = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const idx = getNearestPointIndex(mxCanvas, myCanvas); + if (idx !== -1 && idx !== 0 && idx !== localPoints.length - 1) { + setDragIndex(idx); + return; + } + // add new point + const xVal = canvasToValueX(mxCanvas); + const yVal = canvasToValueY(myCanvas); + const next = sortPoints([...localPoints, [xVal, yVal]]); + setLocalPoints(next); + setDragIndex(next.findIndex(([x, y]) => x === xVal && y === yVal)); + }, + [canvasToValueX, canvasToValueY, getNearestPointIndex, localPoints] + ); + + const handlePointerMove = useCallback( + (e: React.PointerEvent) => { + e.preventDefault(); + e.stopPropagation(); + if (dragIndex === null) { + return; + } + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const mxVal = canvasToValueX(mxCanvas); + const myVal = canvasToValueY(myCanvas); + setLocalPoints((prev) => { + const next = [...prev]; + // clamp endpoints to ends and keep them immutable + if (dragIndex === 0) { + return prev; + } + if (dragIndex === prev.length - 1) { + return prev; + } + next[dragIndex] = [mxVal, myVal]; + return sortPoints(next); + }); + }, + [canvasToValueX, canvasToValueY, dragIndex] + ); + + const commit = useCallback( + (pts: Array<[number, number]>) => { + onChange(sortPoints(pts)); + }, + [onChange] + ); + + const handlePointerUp = useCallback(() => { + setDragIndex(null); + commit(localPoints); + }, [commit, localPoints]); + + const handleDoubleClick = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + const c = canvasRef.current; + if (!c) { + return; + } + const rect = c.getBoundingClientRect(); + const scaleX = c.width / rect.width; + const scaleY = c.height / rect.height; + const mxCanvas = (e.clientX - rect.left) * scaleX; + const myCanvas = (e.clientY - rect.top) * scaleY; + const idx = getNearestPointIndex(mxCanvas, myCanvas); + if (idx > 0 && idx < localPoints.length - 1) { + const next = localPoints.filter((_, i) => i !== idx); + setLocalPoints(next); + commit(next); + } + }, + [commit, getNearestPointIndex, localPoints] + ); + + const canvasStyle = useMemo( + () => ({ width: '100%', height: height, touchAction: 'none', borderRadius: 4, background: '#111' }), + [height] + ); + + return ( +
+ + {title} + + +
+ ); +}); + +export const RasterLayerCurvesEditor = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const adapter = useEntityAdapterContext<'raster_layer'>('raster_layer'); + const { t } = useTranslation(); + const layer = useAppSelector((s) => selectEntity(s.canvas.present, entityIdentifier)) as + | CanvasRasterLayerState + | undefined; + + const [histMaster, setHistMaster] = useState(null); + const [histR, setHistR] = useState(null); + const [histG, setHistG] = useState(null); + const [histB, setHistB] = useState(null); + + const pointsMaster = layer?.adjustments?.curves.master ?? DEFAULT_POINTS; + const pointsR = layer?.adjustments?.curves.r ?? DEFAULT_POINTS; + const pointsG = layer?.adjustments?.curves.g ?? DEFAULT_POINTS; + const pointsB = layer?.adjustments?.curves.b ?? DEFAULT_POINTS; + + const recalcHistogram = useCallback(() => { + try { + const rect = adapter.transformer.getRelativeRect(); + if (rect.width === 0 || rect.height === 0) { + setHistMaster(Array(256).fill(0)); + setHistR(Array(256).fill(0)); + setHistG(Array(256).fill(0)); + setHistB(Array(256).fill(0)); + return; + } + const imageData = adapter.renderer.getImageData({ rect }); + const data = imageData.data; + const len = data.length / 4; + const master = new Array(256).fill(0); + const r = new Array(256).fill(0); + const g = new Array(256).fill(0); + const b = new Array(256).fill(0); + // sample every 4th pixel to lighten work + for (let i = 0; i < len; i += 4) { + const idx = i * 4; + const rv = data[idx] as number; + const gv = data[idx + 1] as number; + const bv = data[idx + 2] as number; + const m = Math.round(0.2126 * rv + 0.7152 * gv + 0.0722 * bv); + if (m >= 0 && m < 256) { + master[m] = (master[m] ?? 0) + 1; + } + if (rv >= 0 && rv < 256) { + r[rv] = (r[rv] ?? 0) + 1; + } + if (gv >= 0 && gv < 256) { + g[gv] = (g[gv] ?? 0) + 1; + } + if (bv >= 0 && bv < 256) { + b[bv] = (b[bv] ?? 0) + 1; + } + } + setHistMaster(master); + setHistR(r); + setHistG(g); + setHistB(b); + } catch { + // ignore + } + }, [adapter]); + + useEffect(() => { + recalcHistogram(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [layer?.objects, layer?.adjustments]); + + const onChangePoints = useCallback( + (channel: Channel, pts: Array<[number, number]>) => { + dispatch(rasterLayerAdjustmentsCurvesUpdated({ entityIdentifier, channel, points: pts })); + }, + [dispatch, entityIdentifier] + ); + + // Memoize per-channel change handlers to avoid inline lambdas in JSX + const onChangeMaster = useCallback((pts: Array<[number, number]>) => onChangePoints('master', pts), [onChangePoints]); + const onChangeR = useCallback((pts: Array<[number, number]>) => onChangePoints('r', pts), [onChangePoints]); + const onChangeG = useCallback((pts: Array<[number, number]>) => onChangePoints('g', pts), [onChangePoints]); + const onChangeB = useCallback((pts: Array<[number, number]>) => onChangePoints('b', pts), [onChangePoints]); + + const gridStyles: React.CSSProperties = useMemo( + () => ({ display: 'grid', gridTemplateColumns: 'repeat(2, minmax(0, 1fr))', gap: 8 }), + [] + ); + + return ( + +
+ + + + +
+
+ ); +}); + +RasterLayerCurvesEditor.displayName = 'RasterLayerCurvesEditor'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx index 65a16a7b4f9..708f7f29cd6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItems.tsx @@ -9,6 +9,7 @@ import { CanvasEntityMenuItemsMergeDown } from 'features/controlLayers/component import { CanvasEntityMenuItemsSave } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSave'; import { CanvasEntityMenuItemsSelectObject } from 'features/controlLayers/components/common/CanvasEntityMenuItemsSelectObject'; import { CanvasEntityMenuItemsTransform } from 'features/controlLayers/components/common/CanvasEntityMenuItemsTransform'; +import { RasterLayerMenuItemsAdjustments } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments'; import { RasterLayerMenuItemsConvertToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsConvertToSubMenu'; import { RasterLayerMenuItemsCopyToSubMenu } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItemsCopyToSubMenu'; import { memo } from 'react'; @@ -21,10 +22,10 @@ export const RasterLayerMenuItems = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx new file mode 100644 index 00000000000..77a939b7bf9 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/RasterLayer/RasterLayerMenuItemsAdjustments.tsx @@ -0,0 +1,38 @@ +import { MenuItem } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { rasterLayerAdjustmentsReset, rasterLayerAdjustmentsSet } from 'features/controlLayers/store/canvasSlice'; +import type { CanvasRasterLayerState } from 'features/controlLayers/store/types'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiSlidersHorizontalBold } from 'react-icons/pi'; + +export const RasterLayerMenuItemsAdjustments = memo(() => { + const dispatch = useAppDispatch(); + const entityIdentifier = useEntityIdentifierContext<'raster_layer'>(); + const { t } = useTranslation(); + const layer = useAppSelector((s) => + s.canvas.present.rasterLayers.entities.find((e: CanvasRasterLayerState) => e.id === entityIdentifier.id) + ); + const hasAdjustments = Boolean(layer?.adjustments); + const onToggleAdjustmentsPresence = useCallback(() => { + if (hasAdjustments) { + dispatch(rasterLayerAdjustmentsReset({ entityIdentifier })); + } else { + dispatch( + rasterLayerAdjustmentsSet({ + entityIdentifier, + adjustments: { enabled: true, collapsed: false, mode: 'simple' }, + }) + ); + } + }, [dispatch, entityIdentifier, hasAdjustments]); + + return ( + }> + {hasAdjustments ? t('controlLayers.removeAdjustments') : t('controlLayers.addAdjustments')} + + ); +}); + +RasterLayerMenuItemsAdjustments.displayName = 'RasterLayerMenuItemsAdjustments'; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts index 6c55e949377..5995e80663c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase.ts @@ -475,7 +475,7 @@ export abstract class CanvasEntityAdapterBase => { const { rect } = this.manager.stateApi.getBbox(); const rasterizeResult = await withResultAsync(() => - this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1, filters: [] } }) + this.renderer.rasterize({ rect, replaceObjects: true, attrs: { opacity: 1 } }) ); if (rasterizeResult.isErr()) { toast({ status: 'error', title: 'Failed to crop to bbox' }); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts index 7e31b594fac..06620584fc5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer.ts @@ -72,7 +72,7 @@ export class CanvasEntityAdapterControlLayer extends CanvasEntityAdapterBase< this.log.trace({ rect }, 'Getting canvas'); // The opacity may have been changed in response to user selecting a different entity category, so we must restore // the original opacity before rendering the canvas - const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] }; + const attrs: GroupConfig = { opacity: this.state.opacity }; const canvas = this.renderer.getCanvas({ rect, attrs }); return canvas; }; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts index ce138c38321..cd8dee6d2f3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer.ts @@ -1,4 +1,4 @@ -import { omit } from 'es-toolkit/compat'; +import { omit, throttle } from 'es-toolkit/compat'; import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase'; import { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; import { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer'; @@ -6,6 +6,7 @@ import { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasE import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasSegmentAnythingModule } from 'features/controlLayers/konva/CanvasSegmentAnythingModule'; +import { AdjustmentsCurvesFilter, AdjustmentsSimpleFilter, buildCurveLUT } from 'features/controlLayers/konva/filters'; import type { CanvasEntityIdentifier, CanvasRasterLayerState, Rect } from 'features/controlLayers/store/types'; import type { GroupConfig } from 'konva/lib/Group'; import type { JsonObject } from 'type-fest'; @@ -59,13 +60,18 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase< if (!prevState || this.state.opacity !== prevState.opacity) { this.syncOpacity(); } + + // Apply per-layer adjustments as a Konva filter + if (!prevState || this.haveAdjustmentsChanged(prevState, this.state)) { + this.syncAdjustmentsFilter(); + } }; getCanvas = (rect?: Rect): HTMLCanvasElement => { this.log.trace({ rect }, 'Getting canvas'); // The opacity may have been changed in response to user selecting a different entity category, so we must restore // the original opacity before rendering the canvas - const attrs: GroupConfig = { opacity: this.state.opacity, filters: [] }; + const attrs: GroupConfig = { opacity: this.state.opacity }; const canvas = this.renderer.getCanvas({ rect, attrs }); return canvas; }; @@ -74,4 +80,79 @@ export class CanvasEntityAdapterRasterLayer extends CanvasEntityAdapterBase< const keysToOmit: (keyof CanvasRasterLayerState)[] = ['name', 'isLocked']; return omit(this.state, keysToOmit); }; + + private syncAdjustmentsFilter = () => { + const a = this.state.adjustments; + const apply = !!a && a.enabled; + // The filter operates on the renderer's object group; we can set filters at the group level via renderer + const group = this.renderer.konva.objectGroup; + if (apply) { + const filters = group.filters() ?? []; + let nextFilters = filters.filter((f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter); + if (a.mode === 'simple') { + group.setAttr('adjustmentsSimple', a.simple); + group.setAttr('adjustmentsCurves', null); + nextFilters = [...nextFilters, AdjustmentsSimpleFilter]; + } else { + // Build LUTs and set curves attr + const master = buildCurveLUT(a.curves.master); + const r = buildCurveLUT(a.curves.r); + const g = buildCurveLUT(a.curves.g); + const b = buildCurveLUT(a.curves.b); + group.setAttr('adjustmentsCurves', { master, r, g, b }); + group.setAttr('adjustmentsSimple', null); + nextFilters = [...nextFilters, AdjustmentsCurvesFilter]; + } + group.filters(nextFilters); + this._throttledCacheRefresh(); + } else { + // Remove our filter if present + const filters = (group.filters() ?? []).filter( + (f: unknown) => f !== AdjustmentsSimpleFilter && f !== AdjustmentsCurvesFilter + ); + group.filters(filters); + group.setAttr('adjustmentsSimple', null); + group.setAttr('adjustmentsCurves', null); + this._throttledCacheRefresh(); + } + }; + + private _throttledCacheRefresh = throttle(() => this.renderer.syncKonvaCache(true), 50); + + private haveAdjustmentsChanged = (prevState: CanvasRasterLayerState, currState: CanvasRasterLayerState): boolean => { + const pa = prevState.adjustments; + const ca = currState.adjustments; + if (pa === ca) { + return false; + } + if (!pa || !ca) { + return true; + } + if (pa.enabled !== ca.enabled) { + return true; + } + if (pa.mode !== ca.mode) { + return true; + } + // simple params + const ps = pa.simple; + const cs = ca.simple; + if ( + ps.brightness !== cs.brightness || + ps.contrast !== cs.contrast || + ps.saturation !== cs.saturation || + ps.temperature !== cs.temperature || + ps.tint !== cs.tint || + ps.sharpness !== cs.sharpness + ) { + return true; + } + // curves reference (UI not implemented yet) - if arrays differ by ref, consider changed + const pc = pa.curves; + const cc = ca.curves; + if (pc !== cc) { + return true; + } + return false; + }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts index 34c5c9ac5de..6a8a704e136 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/filters.ts @@ -20,3 +20,188 @@ export const LightnessToAlphaFilter = (imageData: ImageData): void => { imageData.data[i * 4 + 3] = Math.min(a, (cMin + cMax) / 2); } }; + +// Utility clamp +const clamp = (v: number, min: number, max: number) => (v < min ? min : v > max ? max : v); + +type SimpleAdjustParams = { + brightness: number; // -1..1 (additive) + contrast: number; // -1..1 (scale around 128) + saturation: number; // -1..1 + temperature: number; // -1..1 (blue<->yellow approx) + tint: number; // -1..1 (green<->magenta approx) + sharpness: number; // -1..1 (light unsharp mask) +}; + +/** + * Per-layer simple adjustments filter (brightness, contrast, saturation, temp, tint, sharpness). + * + * Parameters are read from the Konva node attr `adjustmentsSimple` set by the adapter. + */ +type KonvaFilterThis = { getAttr?: (key: string) => unknown }; +export const AdjustmentsSimpleFilter = function (this: KonvaFilterThis, imageData: ImageData): void { + const params = (this?.getAttr?.('adjustmentsSimple') as SimpleAdjustParams | undefined) ?? null; + if (!params) { + return; + } + + const { brightness, contrast, saturation, temperature, tint, sharpness } = params; + + const data = imageData.data; + const len = data.length / 4; + const width = (imageData as ImageData & { width: number }).width ?? 0; + const height = (imageData as ImageData & { height: number }).height ?? 0; + + // Precompute factors + const brightnessShift = brightness * 255; // additive shift + const contrastFactor = 1 + contrast; // scale around 128 + + // Temperature/Tint multipliers + const tempK = 0.5; + const tintK = 0.5; + const rTempMul = 1 + temperature * tempK; + const bTempMul = 1 - temperature * tempK; + // Tint: green <-> magenta. Positive = magenta (R/B up, G down). Negative = green (G up, R/B down). + const t = clamp(tint, -1, 1) * tintK; + const mag = Math.abs(t); + const rTintMul = t >= 0 ? 1 + mag : 1 - mag; + const gTintMul = t >= 0 ? 1 - mag : 1 + mag; + const bTintMul = t >= 0 ? 1 + mag : 1 - mag; + + // Saturation matrix (HSL-based approximation via luma coefficients) + const lumaR = 0.2126; + const lumaG = 0.7152; + const lumaB = 0.0722; + const S = 1 + saturation; // 0..2 + const m00 = lumaR * (1 - S) + S; + const m01 = lumaG * (1 - S); + const m02 = lumaB * (1 - S); + const m10 = lumaR * (1 - S); + const m11 = lumaG * (1 - S) + S; + const m12 = lumaB * (1 - S); + const m20 = lumaR * (1 - S); + const m21 = lumaG * (1 - S); + const m22 = lumaB * (1 - S) + S; + + // First pass: apply per-pixel color adjustments (excluding sharpness) + for (let i = 0; i < len; i++) { + const idx = i * 4; + let r = data[idx + 0] as number; + let g = data[idx + 1] as number; + let b = data[idx + 2] as number; + const a = data[idx + 3] as number; + + // Brightness (additive) + r = r + brightnessShift; + g = g + brightnessShift; + b = b + brightnessShift; + + // Contrast around mid-point 128 + r = (r - 128) * contrastFactor + 128; + g = (g - 128) * contrastFactor + 128; + b = (b - 128) * contrastFactor + 128; + + // Temperature (R/B axis) and Tint (G vs Magenta) + r = r * rTempMul * rTintMul; + g = g * gTintMul; + b = b * bTempMul * bTintMul; + + // Saturation via matrix + const r2 = r * m00 + g * m01 + b * m02; + const g2 = r * m10 + g * m11 + b * m12; + const b2 = r * m20 + g * m21 + b * m22; + + data[idx + 0] = clamp(r2, 0, 255); + data[idx + 1] = clamp(g2, 0, 255); + data[idx + 2] = clamp(b2, 0, 255); + data[idx + 3] = a; + } + + // Optional sharpen (simple unsharp mask with 3x3 kernel) + if (Math.abs(sharpness) > 1e-3 && width > 2 && height > 2) { + const src = new Uint8ClampedArray(data); // copy of modified data + const a = Math.max(-1, Math.min(1, sharpness)) * 0.5; // amount + const center = 1 + 4 * a; + const neighbor = -a; + for (let y = 1; y < height - 1; y++) { + for (let x = 1; x < width - 1; x++) { + const idx = (y * width + x) * 4; + for (let c = 0; c < 3; c++) { + const centerPx = src[idx + c] ?? 0; + const leftPx = src[idx - 4 + c] ?? 0; + const rightPx = src[idx + 4 + c] ?? 0; + const topPx = src[idx - width * 4 + c] ?? 0; + const bottomPx = src[idx + width * 4 + c] ?? 0; + const v = centerPx * center + leftPx * neighbor + rightPx * neighbor + topPx * neighbor + bottomPx * neighbor; + data[idx + c] = clamp(v, 0, 255); + } + // preserve alpha + } + } + } +}; + +// Build a 256-length LUT from 0..255 control points (linear interpolation for v1) +export const buildCurveLUT = (points: Array<[number, number]>): number[] => { + if (!points || points.length === 0) { + return Array.from({ length: 256 }, (_, i) => i); + } + const pts = points + .map(([x, y]) => [clamp(Math.round(x), 0, 255), clamp(Math.round(y), 0, 255)] as [number, number]) + .sort((a, b) => a[0] - b[0]); + if ((pts[0]?.[0] ?? 0) !== 0) { + pts.unshift([0, pts[0]?.[1] ?? 0]); + } + const last = pts[pts.length - 1]; + if ((last?.[0] ?? 255) !== 255) { + pts.push([255, last?.[1] ?? 255]); + } + const lut = new Array(256); + let j = 0; + for (let x = 0; x <= 255; x++) { + while (j < pts.length - 2 && x > (pts[j + 1]?.[0] ?? 255)) { + j++; + } + const p0 = pts[j] ?? [0, 0]; + const p1 = pts[j + 1] ?? [255, 255]; + const [x0, y0] = p0; + const [x1, y1] = p1; + const t = x1 === x0 ? 0 : (x - x0) / (x1 - x0); + const y = y0 + (y1 - y0) * t; + lut[x] = clamp(Math.round(y), 0, 255); + } + return lut; +}; + +type CurvesAdjustParams = { + master: number[]; + r: number[]; + g: number[]; + b: number[]; +}; + +// Curves filter: apply master curve, then per-channel curves +export const AdjustmentsCurvesFilter = function (this: KonvaFilterThis, imageData: ImageData): void { + const params = (this?.getAttr?.('adjustmentsCurves') as CurvesAdjustParams | undefined) ?? null; + if (!params) { + return; + } + const { master, r, g, b } = params; + if (!master || !r || !g || !b) { + return; + } + const data = imageData.data; + const len = data.length / 4; + for (let i = 0; i < len; i++) { + const idx = i * 4; + const r0 = data[idx + 0] as number; + const g0 = data[idx + 1] as number; + const b0 = data[idx + 2] as number; + const rm = master[r0] ?? r0; + const gm = master[g0] ?? g0; + const bm = master[b0] ?? b0; + data[idx + 0] = clamp(r[rm] ?? rm, 0, 255); + data[idx + 1] = clamp(g[gm] ?? gm, 0, 255); + data[idx + 2] = clamp(b[bm] ?? bm, 0, 255); + } +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 61168a0ec5a..bcc8fb9ffc5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -102,6 +102,165 @@ const slice = createSlice({ reducers: { // undoable canvas state //#region Raster layers + rasterLayerAdjustmentsSet: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + adjustments: + | NonNullable + | { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' } + | null; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, adjustments } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (adjustments === null) { + layer.adjustments = null; + return; + } + if (layer.adjustments === null) { + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'simple', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + if (typeof adjustments === 'object' && adjustments !== null && 'version' in adjustments) { + layer.adjustments = merge(layer.adjustments, adjustments as NonNullable); + } else { + // Shallow toggles only + const partial = adjustments as { enabled?: boolean; collapsed?: boolean; mode?: 'simple' | 'curves' }; + layer.adjustments = merge(layer.adjustments, partial); + } + }, + rasterLayerAdjustmentsReset: (state, action: PayloadAction>) => { + const { entityIdentifier } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + layer.adjustments = null; + }, + rasterLayerAdjustmentsSimpleUpdated: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + simple: Partial['simple']>>; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, simple } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (!layer.adjustments) { + // initialize baseline + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'simple', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + layer.adjustments.simple = merge(layer.adjustments.simple, simple); + }, + rasterLayerAdjustmentsCurvesUpdated: ( + state, + action: PayloadAction< + EntityIdentifierPayload< + { + channel: 'master' | 'r' | 'g' | 'b'; + points: Array<[number, number]>; + }, + 'raster_layer' + > + > + ) => { + const { entityIdentifier, channel, points } = action.payload; + const layer = selectEntity(state, entityIdentifier); + if (!layer) { + return; + } + if (!layer.adjustments) { + // initialize baseline + layer.adjustments = { + version: 1, + enabled: true, + collapsed: false, + mode: 'curves', + simple: { brightness: 0, contrast: 0, saturation: 0, temperature: 0, tint: 0, sharpness: 0 }, + curves: { + master: [ + [0, 0], + [255, 255], + ], + r: [ + [0, 0], + [255, 255], + ], + g: [ + [0, 0], + [255, 255], + ], + b: [ + [0, 0], + [255, 255], + ], + }, + }; + } + layer.adjustments.curves[channel] = points; + }, rasterLayerAdded: { reducer: ( state, @@ -1621,6 +1780,11 @@ export const { entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, + // Raster layer adjustments + rasterLayerAdjustmentsSet, + rasterLayerAdjustmentsReset, + rasterLayerAdjustmentsSimpleUpdated, + rasterLayerAdjustmentsCurvesUpdated, entityDeleted, entityArrangedForwardOne, entityArrangedToFront, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index f771f8c7469..fdba9fa3f03 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -368,6 +368,32 @@ const zCanvasRasterLayerState = zCanvasEntityBase.extend({ position: zCoordinate, opacity: zOpacity, objects: z.array(zCanvasObjectState), + // Optional per-layer color adjustments (simple + curves). When null/undefined, no adjustments are applied. + adjustments: z + .object({ + version: z.literal(1), + enabled: z.boolean(), + collapsed: z.boolean(), + mode: z.enum(['simple', 'curves']), + simple: z.object({ + // All simple params normalized to [-1, 1] except sharpness [0, 1] + brightness: z.number().gte(-1).lte(1), + contrast: z.number().gte(-1).lte(1), + saturation: z.number().gte(-1).lte(1), + temperature: z.number().gte(-1).lte(1), + tint: z.number().gte(-1).lte(1), + sharpness: z.number().gte(0).lte(1), + }), + curves: z.object({ + // Curves are arrays of [x, y] control points in 0..255 space (no strict monotonic checks here) + master: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + r: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + g: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + b: z.array(z.tuple([z.number().int().min(0).max(255), z.number().int().min(0).max(255)])).min(2), + }), + }) + .optional() + .nullable(), }); export type CanvasRasterLayerState = z.infer; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts index 2d40cf17793..e14cfd546f4 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts @@ -198,6 +198,7 @@ export const getRasterLayerState = ( objects: [], opacity: 1, position: { x: 0, y: 0 }, + adjustments: null, }; merge(entityState, overrides); return entityState;