Skip to content

Commit 9af146a

Browse files
committed
feat(workflow): support dynamic input/output schemas in node definitions
1 parent 713612c commit 9af146a

File tree

3 files changed

+138
-11
lines changed

3 files changed

+138
-11
lines changed

packages/workflow/src/index.ts

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ function validateNodeInput(
7171
): boolean {
7272
if (!definition.inputSchema) return true
7373
try {
74-
definition.inputSchema.parse(input)
74+
const schema =
75+
typeof definition.inputSchema === 'function'
76+
? definition.inputSchema(node)
77+
: definition.inputSchema
78+
schema.parse(input)
7579
return true
7680
} catch {
7781
return false
@@ -81,13 +85,18 @@ function validateNodeInput(
8185
function validateNodeOutput(
8286
definition: NodeDefinition,
8387
output: NodeIO,
84-
context: NodeContext
88+
context: NodeContext,
89+
node: WorkflowNode
8590
): boolean {
8691
if (!definition.outputSchema) return true
8792
try {
88-
for (const [key, schema] of Object.entries(definition.outputSchema)) {
93+
const schema =
94+
typeof definition.outputSchema === 'function'
95+
? definition.outputSchema(node)
96+
: definition.outputSchema
97+
for (const [key, schemaType] of Object.entries(schema)) {
8998
if (key in output) {
90-
schema.parse(output[key])
99+
schemaType.parse(output[key])
91100
}
92101
}
93102
return true
@@ -258,6 +267,11 @@ export async function executeWorkflow(
258267
timing.duration = timing.endTime - timing.startTime
259268
callbacks.onNodeError?.(node.id, node.type, error, timing)
260269
failed.add(node.id)
270+
results[node.id] = {
271+
state: 'failed' as const,
272+
output: {},
273+
error
274+
}
261275
return
262276
}
263277

@@ -272,7 +286,14 @@ export async function executeWorkflow(
272286
timing.endTime = Date.now()
273287
timing.duration = timing.endTime - timing.startTime
274288

275-
if (!validateNodeOutput(definition, output, context)) {
289+
if (
290+
!validateNodeOutput(
291+
definition,
292+
output,
293+
context,
294+
node
295+
)
296+
) {
276297
throw new Error(
277298
`Invalid output for node ${node.id}`
278299
)

packages/workflow/src/types.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ export type NodeDefinition<
3030
TData = unknown
3131
> = {
3232
run: (input: TInput, context: NodeContext, data?: TData) => Promise<TOutput>
33-
inputSchema?: z.ZodType<TInput>
34-
outputSchema?: Record<string, z.ZodType<unknown>>
33+
inputSchema?:
34+
| z.ZodType<TInput>
35+
| ((node: WorkflowNode) => z.ZodType<TInput>)
36+
outputSchema?:
37+
| Record<string, z.ZodType<unknown>>
38+
| ((node: WorkflowNode) => Record<string, z.ZodType<unknown>>)
3539
dataSchema?: z.ZodType<TData>
3640
}
3741

packages/workflow/tests/example.spec.ts

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,17 @@ describe('Workflow with Optional Outputs and Skipped Nodes', () => {
143143
}
144144

145145
const [lastNodeResult, allResults] = await runExample()
146-
146+
147147
// Verify last node result (evenHandler1)
148148
expect(lastNodeResult.state).to.equal('completed')
149149
expect(lastNodeResult.output).to.deep.equal({ doubled: 8 })
150-
150+
151151
// Verify all node results
152-
expect(allResults).to.have.all.keys('processor1', 'evenHandler1', 'oddHandler1')
152+
expect(allResults).to.have.all.keys(
153+
'processor1',
154+
'evenHandler1',
155+
'oddHandler1'
156+
)
153157
expect(allResults.processor1.state).to.equal('completed')
154158
expect(allResults.processor1.output).to.deep.equal({ even: 4 })
155159
expect(allResults.evenHandler1.state).to.equal('completed')
@@ -284,7 +288,11 @@ describe('Workflow with Optional Outputs and Skipped Nodes', () => {
284288
expect(lastNodeResult.output).to.deep.equal({ doubled: 10 })
285289

286290
// Verify all node results
287-
expect(allResults).to.have.all.keys('if1', 'trueHandler1', 'falseHandler1')
291+
expect(allResults).to.have.all.keys(
292+
'if1',
293+
'trueHandler1',
294+
'falseHandler1'
295+
)
288296
expect(allResults.if1.state).to.equal('completed')
289297
expect(allResults.if1.output).to.deep.equal({ true: 5 })
290298
expect(allResults.trueHandler1.state).to.equal('completed')
@@ -396,4 +404,98 @@ describe('Workflow with Optional Outputs and Skipped Nodes', () => {
396404
expect(allResults.divide.state).to.equal('completed')
397405
expect(allResults.divide.output).to.deep.equal({ result: 5 })
398406
})
407+
408+
it('should handle dynamic input/output schemas based on node config', async () => {
409+
// Define node types with dynamic schemas
410+
type EvalInput = { value: number | string }
411+
type EvalOutput = { result: unknown }
412+
413+
const evalNode: NodeDefinition<EvalInput, EvalOutput> = {
414+
inputSchema: (node: WorkflowNode) => {
415+
// Dynamic input schema based on node config
416+
const inputType = node.config?.inputType || 'number'
417+
return z.object({
418+
value: inputType === 'string' ? z.string() : z.number()
419+
})
420+
},
421+
outputSchema: (node: WorkflowNode) => {
422+
// Dynamic output schema based on node config
423+
const outputType = node.config?.outputType || 'number'
424+
return {
425+
result: outputType === 'string' ? z.string() : z.number()
426+
}
427+
},
428+
run: async (input: EvalInput) => {
429+
// Simple evaluation for demonstration
430+
const result = input.value.toString()
431+
return { result }
432+
}
433+
}
434+
435+
// Create and configure workflow
436+
const factory = createNodeFactory()
437+
factory.registerNode('eval', evalNode)
438+
439+
const workflow: WorkflowNode[] = [
440+
{
441+
id: 'eval1',
442+
type: 'eval',
443+
dependencies: [],
444+
config: {
445+
inputType: 'number',
446+
outputType: 'string'
447+
}
448+
}
449+
]
450+
451+
const initialContext: NodeContext = {
452+
variables: {
453+
value: 42
454+
},
455+
metadata: {}
456+
}
457+
458+
const [lastNodeResult, allResults] = await executeWorkflow(
459+
workflow,
460+
factory,
461+
initialContext,
462+
{
463+
maxRetries: 2,
464+
maxParallel: 2
465+
}
466+
)
467+
468+
// Verify results
469+
expect(lastNodeResult.state).to.equal('completed')
470+
expect(lastNodeResult.output).to.deep.equal({ result: '42' })
471+
expect(allResults.eval1.state).to.equal('completed')
472+
expect(allResults.eval1.output).to.deep.equal({ result: '42' })
473+
474+
// Test with invalid input type
475+
const invalidWorkflow: WorkflowNode[] = [
476+
{
477+
id: 'eval2',
478+
type: 'eval',
479+
dependencies: [],
480+
config: {
481+
inputType: 'string',
482+
outputType: 'number'
483+
}
484+
}
485+
]
486+
487+
const [invalidLastResult, invalidResults] = await executeWorkflow(
488+
invalidWorkflow,
489+
factory,
490+
initialContext,
491+
{
492+
maxRetries: 2,
493+
maxParallel: 2
494+
}
495+
)
496+
497+
// Verify validation failure
498+
expect(invalidLastResult.state).to.equal('failed')
499+
expect(invalidResults.eval2.state).to.equal('failed')
500+
})
399501
})

0 commit comments

Comments
 (0)