|
| 1 | +import { createAsyncThunk, createSlice } from '@reduxjs/toolkit'; |
| 2 | +import { logger } from 'app/logging/logger'; |
| 3 | +import type { RootState } from 'app/store/store'; |
| 4 | +import type { SliceConfig } from 'app/store/types'; |
| 5 | +import { deepClone } from 'common/util/deepClone'; |
| 6 | +import { parseify } from 'common/util/serialize'; |
| 7 | +import { $templates } from 'features/nodes/store/nodesSlice'; |
| 8 | +import type { Templates } from 'features/nodes/store/types'; |
| 9 | +import type { WorkflowV3 } from 'features/nodes/types/workflow'; |
| 10 | +import { zWorkflowV3 } from 'features/nodes/types/workflow'; |
| 11 | +import { serializeError } from 'serialize-error'; |
| 12 | +import { workflowsApi } from 'services/api/endpoints/workflows'; |
| 13 | +import { z } from 'zod'; |
| 14 | + |
| 15 | +const log = logger('canvas'); |
| 16 | + |
| 17 | +const zCanvasWorkflowState = z.object({ |
| 18 | + selectedWorkflowId: z.string().nullable(), |
| 19 | + workflow: zWorkflowV3.nullable(), |
| 20 | + inputNodeId: z.string().nullable(), |
| 21 | + outputNodeId: z.string().nullable(), |
| 22 | + status: z.enum(['idle', 'loading', 'succeeded', 'failed']), |
| 23 | + error: z.string().nullable(), |
| 24 | +}); |
| 25 | + |
| 26 | +export type CanvasWorkflowState = z.infer<typeof zCanvasWorkflowState>; |
| 27 | + |
| 28 | +const getInitialState = (): CanvasWorkflowState => ({ |
| 29 | + selectedWorkflowId: null, |
| 30 | + workflow: null, |
| 31 | + inputNodeId: null, |
| 32 | + outputNodeId: null, |
| 33 | + status: 'idle', |
| 34 | + error: null, |
| 35 | +}); |
| 36 | + |
| 37 | +type ValidateResult = { |
| 38 | + inputNodeId: string; |
| 39 | + outputNodeId: string; |
| 40 | +}; |
| 41 | + |
| 42 | +const INPUT_TAG = 'canvas-workflow-input'; |
| 43 | +const OUTPUT_TAG = 'canvas-workflow-output'; |
| 44 | + |
| 45 | +const validateCanvasWorkflow = (workflow: WorkflowV3, templates: Templates): ValidateResult => { |
| 46 | + const invocationNodes = workflow.nodes.filter( |
| 47 | + (node): node is WorkflowV3['nodes'][number] => node.type === 'invocation' |
| 48 | + ); |
| 49 | + |
| 50 | + const inputNodes = invocationNodes.filter((node) => { |
| 51 | + const template = templates[node.data.type]; |
| 52 | + return Boolean(template && template.tags.includes(INPUT_TAG)); |
| 53 | + }); |
| 54 | + |
| 55 | + const outputNodes = invocationNodes.filter((node) => { |
| 56 | + const template = templates[node.data.type]; |
| 57 | + return Boolean(template && template.tags.includes(OUTPUT_TAG)); |
| 58 | + }); |
| 59 | + |
| 60 | + if (inputNodes.length !== 1) { |
| 61 | + throw new Error('A canvas workflow must include exactly one input node.'); |
| 62 | + } |
| 63 | + |
| 64 | + if (outputNodes.length !== 1) { |
| 65 | + throw new Error('A canvas workflow must include exactly one output node.'); |
| 66 | + } |
| 67 | + |
| 68 | + const inputNode = inputNodes[0]!; |
| 69 | + const outputNode = outputNodes[0]!; |
| 70 | + |
| 71 | + const inputTemplate = templates[inputNode.data.type]; |
| 72 | + if (!inputTemplate) { |
| 73 | + throw new Error(`Input node template "${inputNode.data.type}" not found.`); |
| 74 | + } |
| 75 | + if (!('image' in inputTemplate.inputs)) { |
| 76 | + throw new Error('Canvas input node must expose an image field.'); |
| 77 | + } |
| 78 | + |
| 79 | + const outputTemplate = templates[outputNode.data.type]; |
| 80 | + if (!outputTemplate) { |
| 81 | + throw new Error(`Output node template "${outputNode.data.type}" not found.`); |
| 82 | + } |
| 83 | + if (!('image' in outputTemplate.inputs)) { |
| 84 | + throw new Error('Canvas output node must accept an image input field named "image".'); |
| 85 | + } |
| 86 | + |
| 87 | + return { inputNodeId: inputNode.id, outputNodeId: outputNode.id }; |
| 88 | +}; |
| 89 | + |
| 90 | +export const selectCanvasWorkflow = createAsyncThunk< |
| 91 | + { workflowId: string; workflow: WorkflowV3; inputNodeId: string; outputNodeId: string }, |
| 92 | + string, |
| 93 | + { rejectValue: string } |
| 94 | +>('canvasWorkflow/select', async (workflowId, { dispatch, rejectWithValue }) => { |
| 95 | + const request = dispatch(workflowsApi.endpoints.getWorkflow.initiate(workflowId, { subscribe: false })); |
| 96 | + try { |
| 97 | + const result = await request.unwrap(); |
| 98 | + const workflow = zWorkflowV3.parse(deepClone(result.workflow)); |
| 99 | + const templates = $templates.get(); |
| 100 | + if (!Object.keys(templates).length) { |
| 101 | + throw new Error('Invocation templates are not yet available.'); |
| 102 | + } |
| 103 | + const { inputNodeId, outputNodeId } = validateCanvasWorkflow(workflow, templates); |
| 104 | + return { workflowId: result.workflow_id, workflow, inputNodeId, outputNodeId }; |
| 105 | + } catch (error) { |
| 106 | + const message = error instanceof Error ? error.message : 'Unable to load workflow.'; |
| 107 | + log.error({ error: serializeError(error as Error) }, 'Failed to load canvas workflow'); |
| 108 | + return rejectWithValue(message); |
| 109 | + } finally { |
| 110 | + request.unsubscribe(); |
| 111 | + } |
| 112 | +}); |
| 113 | + |
| 114 | +const slice = createSlice({ |
| 115 | + name: 'canvasWorkflow', |
| 116 | + initialState: getInitialState(), |
| 117 | + reducers: { |
| 118 | + canvasWorkflowCleared: () => getInitialState(), |
| 119 | + }, |
| 120 | + extraReducers(builder) { |
| 121 | + builder |
| 122 | + .addCase(selectCanvasWorkflow.pending, (state) => { |
| 123 | + state.status = 'loading'; |
| 124 | + state.error = null; |
| 125 | + }) |
| 126 | + .addCase(selectCanvasWorkflow.fulfilled, (state, action) => { |
| 127 | + state.selectedWorkflowId = action.payload.workflowId; |
| 128 | + state.workflow = action.payload.workflow; |
| 129 | + state.inputNodeId = action.payload.inputNodeId; |
| 130 | + state.outputNodeId = action.payload.outputNodeId; |
| 131 | + state.status = 'succeeded'; |
| 132 | + state.error = null; |
| 133 | + }) |
| 134 | + .addCase(selectCanvasWorkflow.rejected, (state, action) => { |
| 135 | + state.status = 'failed'; |
| 136 | + state.error = action.payload ?? action.error.message ?? 'Unable to load workflow.'; |
| 137 | + }); |
| 138 | + }, |
| 139 | +}); |
| 140 | + |
| 141 | +export const { canvasWorkflowCleared } = slice.actions; |
| 142 | + |
| 143 | +export const canvasWorkflowSliceConfig: SliceConfig<typeof slice> = { |
| 144 | + slice, |
| 145 | + schema: zCanvasWorkflowState, |
| 146 | + getInitialState, |
| 147 | + persistConfig: { |
| 148 | + migrate: (state) => { |
| 149 | + const parsed = zCanvasWorkflowState.safeParse(state); |
| 150 | + if (!parsed.success) { |
| 151 | + log.warn({ error: parseify(parsed.error) }, 'Failed to migrate canvas workflow state, resetting to defaults'); |
| 152 | + return getInitialState(); |
| 153 | + } |
| 154 | + return { |
| 155 | + ...parsed.data, |
| 156 | + status: 'idle', |
| 157 | + error: null, |
| 158 | + } satisfies CanvasWorkflowState; |
| 159 | + }, |
| 160 | + persistDenylist: ['status', 'error'], |
| 161 | + }, |
| 162 | +}; |
| 163 | + |
| 164 | +export const selectCanvasWorkflowSlice = (state: RootState) => state.canvasWorkflow; |
| 165 | + |
| 166 | +export const selectCanvasWorkflowStatus = (state: RootState) => selectCanvasWorkflowSlice(state).status; |
| 167 | + |
| 168 | +export const selectCanvasWorkflowError = (state: RootState) => selectCanvasWorkflowSlice(state).error; |
| 169 | + |
| 170 | +export const selectCanvasWorkflowSelection = (state: RootState) => selectCanvasWorkflowSlice(state).selectedWorkflowId; |
| 171 | + |
| 172 | +export const selectCanvasWorkflowData = (state: RootState) => selectCanvasWorkflowSlice(state).workflow; |
| 173 | + |
| 174 | +export const selectCanvasWorkflowNodeIds = (state: RootState) => ({ |
| 175 | + inputNodeId: selectCanvasWorkflowSlice(state).inputNodeId, |
| 176 | + outputNodeId: selectCanvasWorkflowSlice(state).outputNodeId, |
| 177 | +}); |
| 178 | + |
| 179 | +export const selectIsCanvasWorkflowActive = (state: RootState) => { |
| 180 | + const sliceState = selectCanvasWorkflowSlice(state); |
| 181 | + return ( |
| 182 | + Boolean(sliceState.workflow && sliceState.inputNodeId && sliceState.outputNodeId) && |
| 183 | + (sliceState.status === 'succeeded' || sliceState.status === 'idle') |
| 184 | + ); |
| 185 | +}; |
0 commit comments