Skip to content

Commit b24c2cb

Browse files
Mary HippMary Hipp
authored andcommitted
introduce new canvas_workflows that can be used to indicate inputs and outputs of the canvas, build UI where use can select a workflow with these nodes to run against canvas
1 parent 5885db4 commit b24c2cb

File tree

17 files changed

+964
-131
lines changed

17 files changed

+964
-131
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Canvas workflow bridge invocations."""
2+
3+
from invokeai.app.invocations.baseinvocation import (
4+
BaseInvocation,
5+
Classification,
6+
invocation,
7+
)
8+
from invokeai.app.invocations.fields import ImageField, Input, InputField, WithBoard, WithMetadata
9+
from invokeai.app.invocations.primitives import ImageOutput
10+
from invokeai.app.services.shared.invocation_context import InvocationContext
11+
12+
13+
@invocation(
14+
"canvas_composite_raster_input",
15+
title="Canvas Composite Input",
16+
tags=["canvas", "workflow", "canvas-workflow-input"],
17+
category="canvas",
18+
version="1.0.0",
19+
classification=Classification.Beta,
20+
)
21+
class CanvasCompositeRasterInputInvocation(BaseInvocation, WithMetadata, WithBoard):
22+
"""Provides the flattened canvas raster layer to a workflow."""
23+
24+
image: ImageField = InputField(
25+
description="The flattened canvas raster layer.",
26+
input=Input.Direct,
27+
)
28+
29+
def invoke(self, context: InvocationContext) -> ImageOutput:
30+
image_dto = context.images.get_dto(self.image.image_name)
31+
return ImageOutput.build(image_dto=image_dto)
32+
33+
34+
@invocation(
35+
"canvas_workflow_output",
36+
title="Canvas Workflow Output",
37+
tags=["canvas", "workflow", "canvas-workflow-output"],
38+
category="canvas",
39+
version="1.0.0",
40+
classification=Classification.Beta,
41+
)
42+
class CanvasWorkflowOutputInvocation(BaseInvocation, WithMetadata, WithBoard):
43+
"""Designates the workflow image output used by the canvas."""
44+
45+
image: ImageField = InputField(
46+
description="The workflow's resulting image.",
47+
input=Input.Connection,
48+
)
49+
50+
def invoke(self, context: InvocationContext) -> ImageOutput:
51+
image_dto = context.images.get_dto(self.image.image_name)
52+
return ImageOutput.build(image_dto=image_dto)

invokeai/frontend/web/public/locales/en.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,16 @@
21282128
"recalculateRects": "Recalculate Rects",
21292129
"clipToBbox": "Clip Strokes to Bbox",
21302130
"outputOnlyMaskedRegions": "Output Only Generated Regions",
2131+
"canvasWorkflowLabel": "Canvas Workflow",
2132+
"canvasWorkflowInstructions": "Select a workflow containing the canvas composite input and canvas workflow output nodes to drive custom canvas generation.",
2133+
"canvasWorkflowSelectedDescription": "This workflow is currently configured for canvas generation.",
2134+
"canvasWorkflowSelectButton": "Select Workflow",
2135+
"canvasWorkflowSelected": "Canvas workflow selected",
2136+
"canvasWorkflowModalTitle": "Select Canvas Workflow",
2137+
"canvasWorkflowModalDescription": "Choose a workflow containing the canvas composite input and canvas workflow output nodes. Only workflows that meet these requirements can be used from the canvas.",
2138+
"selectCanvasWorkflowTooltip": "Select a workflow to run from the canvas",
2139+
"changeCanvasWorkflowTooltip": "Change canvas workflow",
2140+
"canvasWorkflowChangeButton": "Change Workflow",
21312141
"addLayer": "Add Layer",
21322142
"duplicate": "Duplicate",
21332143
"moveToFront": "Move to Front",

invokeai/frontend/web/src/app/store/store.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/sli
2424
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
2525
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
2626
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
27+
import { canvasWorkflowSliceConfig } from 'features/controlLayers/store/canvasWorkflowSlice';
2728
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
2829
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
2930
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
@@ -65,6 +66,7 @@ const log = logger('system');
6566
const SLICE_CONFIGS = {
6667
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
6768
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
69+
[canvasWorkflowSliceConfig.slice.reducerPath]: canvasWorkflowSliceConfig,
6870
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
6971
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
7072
[configSliceConfig.slice.reducerPath]: configSliceConfig,
@@ -91,6 +93,7 @@ const ALL_REDUCERS = {
9193
[api.reducerPath]: api.reducer,
9294
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
9395
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
96+
[canvasWorkflowSliceConfig.slice.reducerPath]: canvasWorkflowSliceConfig.slice.reducer,
9497
// Undoable!
9598
[canvasSliceConfig.slice.reducerPath]: undoable(
9699
canvasSliceConfig.slice.reducer,

invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.ts

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,28 @@ export const getOutputImageName = (item: S['SessionQueueItem']) => {
2121
)?.[1][0];
2222
const output = nodeId ? item.session.results[nodeId] : undefined;
2323

24-
if (!output) {
24+
const getImageNameFromOutput = (result?: S['GraphExecutionState']['results'][string]) => {
25+
if (!result) {
26+
return null;
27+
}
28+
for (const [_name, value] of objectEntries(result)) {
29+
if (isImageField(value)) {
30+
return value.image_name;
31+
}
32+
}
2533
return null;
34+
};
35+
36+
const imageName = getImageNameFromOutput(output);
37+
if (imageName) {
38+
return imageName;
2639
}
2740

28-
for (const [_name, value] of objectEntries(output)) {
29-
if (isImageField(value)) {
30-
return value.image_name;
41+
// Fallback: search all results for an image field. Custom workflows may not have a canvas_output-prefixed node id.
42+
for (const result of Object.values(item.session.results)) {
43+
const fallbackName = getImageNameFromOutput(result);
44+
if (fallbackName) {
45+
return fallbackName;
3146
}
3247
}
3348

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import { createAsyncThunk, createSlice } from '@reduxjs/toolkit';
2+
import { logger } from 'app/logging/logger';
3+
import type { RootState } from 'app/store/store';
4+
import type { SliceConfig } from 'app/store/types';
5+
import { deepClone } from 'common/util/deepClone';
6+
import { parseify } from 'common/util/serialize';
7+
import { $templates } from 'features/nodes/store/nodesSlice';
8+
import type { Templates } from 'features/nodes/store/types';
9+
import type { WorkflowV3 } from 'features/nodes/types/workflow';
10+
import { zWorkflowV3 } from 'features/nodes/types/workflow';
11+
import { serializeError } from 'serialize-error';
12+
import { workflowsApi } from 'services/api/endpoints/workflows';
13+
import { z } from 'zod';
14+
15+
const log = logger('canvas');
16+
17+
const zCanvasWorkflowState = z.object({
18+
selectedWorkflowId: z.string().nullable(),
19+
workflow: zWorkflowV3.nullable(),
20+
inputNodeId: z.string().nullable(),
21+
outputNodeId: z.string().nullable(),
22+
status: z.enum(['idle', 'loading', 'succeeded', 'failed']),
23+
error: z.string().nullable(),
24+
});
25+
26+
export type CanvasWorkflowState = z.infer<typeof zCanvasWorkflowState>;
27+
28+
const getInitialState = (): CanvasWorkflowState => ({
29+
selectedWorkflowId: null,
30+
workflow: null,
31+
inputNodeId: null,
32+
outputNodeId: null,
33+
status: 'idle',
34+
error: null,
35+
});
36+
37+
type ValidateResult = {
38+
inputNodeId: string;
39+
outputNodeId: string;
40+
};
41+
42+
const INPUT_TAG = 'canvas-workflow-input';
43+
const OUTPUT_TAG = 'canvas-workflow-output';
44+
45+
const validateCanvasWorkflow = (workflow: WorkflowV3, templates: Templates): ValidateResult => {
46+
const invocationNodes = workflow.nodes.filter(
47+
(node): node is WorkflowV3['nodes'][number] => node.type === 'invocation'
48+
);
49+
50+
const inputNodes = invocationNodes.filter((node) => {
51+
const template = templates[node.data.type];
52+
return Boolean(template && template.tags.includes(INPUT_TAG));
53+
});
54+
55+
const outputNodes = invocationNodes.filter((node) => {
56+
const template = templates[node.data.type];
57+
return Boolean(template && template.tags.includes(OUTPUT_TAG));
58+
});
59+
60+
if (inputNodes.length !== 1) {
61+
throw new Error('A canvas workflow must include exactly one input node.');
62+
}
63+
64+
if (outputNodes.length !== 1) {
65+
throw new Error('A canvas workflow must include exactly one output node.');
66+
}
67+
68+
const inputNode = inputNodes[0]!;
69+
const outputNode = outputNodes[0]!;
70+
71+
const inputTemplate = templates[inputNode.data.type];
72+
if (!inputTemplate) {
73+
throw new Error(`Input node template "${inputNode.data.type}" not found.`);
74+
}
75+
if (!('image' in inputTemplate.inputs)) {
76+
throw new Error('Canvas input node must expose an image field.');
77+
}
78+
79+
const outputTemplate = templates[outputNode.data.type];
80+
if (!outputTemplate) {
81+
throw new Error(`Output node template "${outputNode.data.type}" not found.`);
82+
}
83+
if (!('image' in outputTemplate.inputs)) {
84+
throw new Error('Canvas output node must accept an image input field named "image".');
85+
}
86+
87+
return { inputNodeId: inputNode.id, outputNodeId: outputNode.id };
88+
};
89+
90+
export const selectCanvasWorkflow = createAsyncThunk<
91+
{ workflowId: string; workflow: WorkflowV3; inputNodeId: string; outputNodeId: string },
92+
string,
93+
{ rejectValue: string }
94+
>('canvasWorkflow/select', async (workflowId, { dispatch, rejectWithValue }) => {
95+
const request = dispatch(workflowsApi.endpoints.getWorkflow.initiate(workflowId, { subscribe: false }));
96+
try {
97+
const result = await request.unwrap();
98+
const workflow = zWorkflowV3.parse(deepClone(result.workflow));
99+
const templates = $templates.get();
100+
if (!Object.keys(templates).length) {
101+
throw new Error('Invocation templates are not yet available.');
102+
}
103+
const { inputNodeId, outputNodeId } = validateCanvasWorkflow(workflow, templates);
104+
return { workflowId: result.workflow_id, workflow, inputNodeId, outputNodeId };
105+
} catch (error) {
106+
const message = error instanceof Error ? error.message : 'Unable to load workflow.';
107+
log.error({ error: serializeError(error as Error) }, 'Failed to load canvas workflow');
108+
return rejectWithValue(message);
109+
} finally {
110+
request.unsubscribe();
111+
}
112+
});
113+
114+
const slice = createSlice({
115+
name: 'canvasWorkflow',
116+
initialState: getInitialState(),
117+
reducers: {
118+
canvasWorkflowCleared: () => getInitialState(),
119+
},
120+
extraReducers(builder) {
121+
builder
122+
.addCase(selectCanvasWorkflow.pending, (state) => {
123+
state.status = 'loading';
124+
state.error = null;
125+
})
126+
.addCase(selectCanvasWorkflow.fulfilled, (state, action) => {
127+
state.selectedWorkflowId = action.payload.workflowId;
128+
state.workflow = action.payload.workflow;
129+
state.inputNodeId = action.payload.inputNodeId;
130+
state.outputNodeId = action.payload.outputNodeId;
131+
state.status = 'succeeded';
132+
state.error = null;
133+
})
134+
.addCase(selectCanvasWorkflow.rejected, (state, action) => {
135+
state.status = 'failed';
136+
state.error = action.payload ?? action.error.message ?? 'Unable to load workflow.';
137+
});
138+
},
139+
});
140+
141+
export const { canvasWorkflowCleared } = slice.actions;
142+
143+
export const canvasWorkflowSliceConfig: SliceConfig<typeof slice> = {
144+
slice,
145+
schema: zCanvasWorkflowState,
146+
getInitialState,
147+
persistConfig: {
148+
migrate: (state) => {
149+
const parsed = zCanvasWorkflowState.safeParse(state);
150+
if (!parsed.success) {
151+
log.warn({ error: parseify(parsed.error) }, 'Failed to migrate canvas workflow state, resetting to defaults');
152+
return getInitialState();
153+
}
154+
return {
155+
...parsed.data,
156+
status: 'idle',
157+
error: null,
158+
} satisfies CanvasWorkflowState;
159+
},
160+
persistDenylist: ['status', 'error'],
161+
},
162+
};
163+
164+
export const selectCanvasWorkflowSlice = (state: RootState) => state.canvasWorkflow;
165+
166+
export const selectCanvasWorkflowStatus = (state: RootState) => selectCanvasWorkflowSlice(state).status;
167+
168+
export const selectCanvasWorkflowError = (state: RootState) => selectCanvasWorkflowSlice(state).error;
169+
170+
export const selectCanvasWorkflowSelection = (state: RootState) => selectCanvasWorkflowSlice(state).selectedWorkflowId;
171+
172+
export const selectCanvasWorkflowData = (state: RootState) => selectCanvasWorkflowSlice(state).workflow;
173+
174+
export const selectCanvasWorkflowNodeIds = (state: RootState) => ({
175+
inputNodeId: selectCanvasWorkflowSlice(state).inputNodeId,
176+
outputNodeId: selectCanvasWorkflowSlice(state).outputNodeId,
177+
});
178+
179+
export const selectIsCanvasWorkflowActive = (state: RootState) => {
180+
const sliceState = selectCanvasWorkflowSlice(state);
181+
return (
182+
Boolean(sliceState.workflow && sliceState.inputNodeId && sliceState.outputNodeId) &&
183+
(sliceState.status === 'succeeded' || sliceState.status === 'idle')
184+
);
185+
};

invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger.tsx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { Button, Text } from '@invoke-ai/ui-library';
22
import { useAppSelector } from 'app/store/storeHooks';
33
import { selectWorkflowName } from 'features/nodes/store/selectors';
44
import { useWorkflowLibraryModal } from 'features/nodes/store/workflowLibraryModal';
5+
import { useCallback } from 'react';
56
import { useTranslation } from 'react-i18next';
67
import { PiFolderOpenFill } from 'react-icons/pi';
78

@@ -10,8 +11,12 @@ export const WorkflowListMenuTrigger = () => {
1011
const { t } = useTranslation();
1112
const workflowName = useAppSelector(selectWorkflowName);
1213

14+
const onClick = useCallback(() => {
15+
workflowLibraryModal.open();
16+
}, [workflowLibraryModal]);
17+
1318
return (
14-
<Button variant="ghost" rightIcon={<PiFolderOpenFill />} size="sm" onClick={workflowLibraryModal.open}>
19+
<Button variant="ghost" rightIcon={<PiFolderOpenFill />} size="sm" onClick={onClick}>
1520
<Text
1621
display="auto"
1722
noOfLines={1}

invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/EmptyState.tsx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ const CleanEditorContent = () => {
2727
dispatch(workflowModeChanged('edit'));
2828
}, [dispatch]);
2929

30+
const onClickBrowseWorkflows = useCallback(() => {
31+
workflowLibraryModal.open();
32+
}, [workflowLibraryModal]);
33+
3034
return (
3135
<Flex flexDir="column" h="full" w="full" alignItems="center">
3236
<Flex flexDir="column" gap={8} w="full" pt="20vh" px={8} maxW={768}>
@@ -39,7 +43,7 @@ const CleanEditorContent = () => {
3943
</Text>
4044
</Flex>
4145
</LaunchpadButton>
42-
<LaunchpadButton onClick={workflowLibraryModal.open} gap={8}>
46+
<LaunchpadButton onClick={onClickBrowseWorkflows} gap={8}>
4347
<Icon as={PiFolderOpenBold} boxSize={6} color="base.500" />
4448
<Flex flexDir="column" alignItems="flex-start" gap={2}>
4549
<Heading size="sm">{t('nodes.loadWorkflow')}</Heading>

0 commit comments

Comments
 (0)