Skip to content

Commit 0087cde

Browse files
feat(compose): enable persisting inputs for re-run nodes
Change-Id: I0b0e8a5d4a2970289eb6f48021f6ed047f6ebec4
1 parent e293e98 commit 0087cde

File tree

5 files changed

+354
-5
lines changed

5 files changed

+354
-5
lines changed

compose/checkpoint_test.go

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,3 +1531,289 @@ func TestCancelInterrupt(t *testing.T) {
15311531
"3": "3",
15321532
}, result2)
15331533
}
1534+
1535+
func TestPersistRerunInputNonStream(t *testing.T) {
1536+
store := newInMemoryStore()
1537+
1538+
var receivedInput string
1539+
var callCount int
1540+
1541+
g := NewGraph[string, string]()
1542+
1543+
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1544+
callCount++
1545+
receivedInput = input
1546+
if callCount == 1 {
1547+
return "", Interrupt(ctx, "interrupt")
1548+
}
1549+
return input + "_processed", nil
1550+
}))
1551+
assert.NoError(t, err)
1552+
1553+
err = g.AddEdge(START, "1")
1554+
assert.NoError(t, err)
1555+
err = g.AddEdge("1", END)
1556+
assert.NoError(t, err)
1557+
1558+
ctx := context.Background()
1559+
r, err := g.Compile(ctx,
1560+
WithNodeTriggerMode(AllPredecessor),
1561+
WithCheckPointStore(store),
1562+
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
1563+
)
1564+
assert.NoError(t, err)
1565+
1566+
_, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1"))
1567+
assert.NotNil(t, err)
1568+
info, ok := ExtractInterruptInfo(err)
1569+
assert.True(t, ok)
1570+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1571+
1572+
assert.Equal(t, "test_input", receivedInput)
1573+
1574+
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1575+
assert.NoError(t, err)
1576+
assert.Equal(t, "test_input_processed", result)
1577+
assert.Equal(t, "test_input", receivedInput)
1578+
assert.Equal(t, 2, callCount)
1579+
}
1580+
1581+
func TestPersistRerunInputStream(t *testing.T) {
1582+
store := newInMemoryStore()
1583+
1584+
var receivedInput string
1585+
var callCount int
1586+
1587+
g := NewGraph[string, string]()
1588+
1589+
err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) {
1590+
callCount++
1591+
1592+
var sb string
1593+
for {
1594+
chunk, err := input.Recv()
1595+
if err == io.EOF {
1596+
break
1597+
}
1598+
if err != nil {
1599+
return nil, err
1600+
}
1601+
sb += chunk
1602+
}
1603+
receivedInput = sb
1604+
1605+
if callCount == 1 {
1606+
return nil, Interrupt(ctx, "interrupt")
1607+
}
1608+
1609+
return schema.StreamReaderFromArray([]string{sb + "_processed"}), nil
1610+
}))
1611+
assert.NoError(t, err)
1612+
1613+
err = g.AddEdge(START, "1")
1614+
assert.NoError(t, err)
1615+
err = g.AddEdge("1", END)
1616+
assert.NoError(t, err)
1617+
1618+
ctx := context.Background()
1619+
r, err := g.Compile(ctx,
1620+
WithNodeTriggerMode(AllPredecessor),
1621+
WithCheckPointStore(store),
1622+
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
1623+
)
1624+
assert.NoError(t, err)
1625+
1626+
inputStream := schema.StreamReaderFromArray([]string{"chunk1", "chunk2", "chunk3"})
1627+
1628+
_, err = r.Transform(ctx, inputStream, WithCheckPointID("cp1"))
1629+
assert.NotNil(t, err)
1630+
info, ok := ExtractInterruptInfo(err)
1631+
assert.True(t, ok)
1632+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1633+
1634+
assert.Equal(t, "chunk1chunk2chunk3", receivedInput)
1635+
1636+
emptyInputStream := schema.StreamReaderFromArray([]string{})
1637+
1638+
resultStream, err := r.Transform(ctx, emptyInputStream, WithCheckPointID("cp1"))
1639+
assert.NoError(t, err)
1640+
1641+
var result string
1642+
for {
1643+
chunk, err := resultStream.Recv()
1644+
if err == io.EOF {
1645+
break
1646+
}
1647+
assert.NoError(t, err)
1648+
result += chunk
1649+
}
1650+
1651+
assert.Equal(t, "chunk1chunk2chunk3_processed", result)
1652+
assert.Equal(t, "chunk1chunk2chunk3", receivedInput)
1653+
assert.Equal(t, 2, callCount)
1654+
}
1655+
1656+
// testPersistRerunInputState is the graph-local state used by
// TestPersistRerunInputWithPreHandler.
type testPersistRerunInputState struct {
	// Prefix is the string the test's state pre-handler puts in front of the
	// node input.
	Prefix string
}
1659+
1660+
func TestPersistRerunInputWithPreHandler(t *testing.T) {
1661+
store := newInMemoryStore()
1662+
1663+
var receivedInput string
1664+
var callCount int
1665+
1666+
schema.Register[testPersistRerunInputState]()
1667+
1668+
g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) *testPersistRerunInputState {
1669+
return &testPersistRerunInputState{Prefix: "prefix_"}
1670+
}))
1671+
1672+
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1673+
callCount++
1674+
receivedInput = input
1675+
if callCount == 1 {
1676+
return "", Interrupt(ctx, "interrupt")
1677+
}
1678+
return input + "_processed", nil
1679+
}), WithStatePreHandler(func(ctx context.Context, in string, s *testPersistRerunInputState) (string, error) {
1680+
return s.Prefix + in, nil
1681+
}))
1682+
assert.NoError(t, err)
1683+
1684+
err = g.AddEdge(START, "1")
1685+
assert.NoError(t, err)
1686+
err = g.AddEdge("1", END)
1687+
assert.NoError(t, err)
1688+
1689+
ctx := context.Background()
1690+
r, err := g.Compile(ctx,
1691+
WithNodeTriggerMode(AllPredecessor),
1692+
WithCheckPointStore(store),
1693+
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
1694+
)
1695+
assert.NoError(t, err)
1696+
1697+
_, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1"))
1698+
assert.NotNil(t, err)
1699+
info, ok := ExtractInterruptInfo(err)
1700+
assert.True(t, ok)
1701+
if ok {
1702+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1703+
}
1704+
1705+
assert.Equal(t, "prefix_test_input", receivedInput)
1706+
1707+
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1708+
assert.NoError(t, err)
1709+
assert.Equal(t, "prefix_test_input_processed", result)
1710+
assert.Equal(t, "prefix_test_input", receivedInput)
1711+
assert.Equal(t, 2, callCount)
1712+
}
1713+
1714+
func TestPersistRerunInputBackwardCompatibility(t *testing.T) {
1715+
store := newInMemoryStore()
1716+
1717+
var receivedInput string
1718+
var callCount int
1719+
1720+
g := NewGraph[string, string]()
1721+
1722+
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1723+
callCount++
1724+
receivedInput = input
1725+
if len(input) > 0 {
1726+
return "", StatefulInterrupt(ctx, "interrupt", input)
1727+
}
1728+
1729+
_, _, restoredInput := GetInterruptState[string](ctx)
1730+
return restoredInput + "_processed", nil
1731+
}))
1732+
assert.NoError(t, err)
1733+
1734+
err = g.AddEdge(START, "1")
1735+
assert.NoError(t, err)
1736+
err = g.AddEdge("1", END)
1737+
assert.NoError(t, err)
1738+
1739+
ctx := context.Background()
1740+
r, err := g.Compile(ctx,
1741+
WithNodeTriggerMode(AllPredecessor),
1742+
WithCheckPointStore(store),
1743+
)
1744+
assert.NoError(t, err)
1745+
1746+
_, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1"))
1747+
assert.NotNil(t, err)
1748+
info, ok := ExtractInterruptInfo(err)
1749+
assert.True(t, ok)
1750+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1751+
1752+
assert.Equal(t, "test_input", receivedInput)
1753+
1754+
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1755+
assert.NoError(t, err)
1756+
assert.Equal(t, "test_input_processed", result)
1757+
assert.Equal(t, "", receivedInput)
1758+
assert.Equal(t, 2, callCount)
1759+
}
1760+
1761+
func TestPersistRerunInputSubGraph(t *testing.T) {
1762+
store := newInMemoryStore()
1763+
1764+
var receivedInput string
1765+
var callCount int
1766+
1767+
subG := NewGraph[string, string]()
1768+
err := subG.AddLambdaNode("sub1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1769+
callCount++
1770+
receivedInput = input
1771+
if callCount == 1 {
1772+
return "", Interrupt(ctx, "interrupt")
1773+
}
1774+
return input + "_sub_processed", nil
1775+
}))
1776+
assert.NoError(t, err)
1777+
err = subG.AddEdge(START, "sub1")
1778+
assert.NoError(t, err)
1779+
err = subG.AddEdge("sub1", END)
1780+
assert.NoError(t, err)
1781+
1782+
g := NewGraph[string, string]()
1783+
err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1784+
return input + "_main", nil
1785+
}))
1786+
assert.NoError(t, err)
1787+
err = g.AddGraphNode("2", subG)
1788+
assert.NoError(t, err)
1789+
err = g.AddEdge(START, "1")
1790+
assert.NoError(t, err)
1791+
err = g.AddEdge("1", "2")
1792+
assert.NoError(t, err)
1793+
err = g.AddEdge("2", END)
1794+
assert.NoError(t, err)
1795+
1796+
ctx := context.Background()
1797+
r, err := g.Compile(ctx,
1798+
WithNodeTriggerMode(AllPredecessor),
1799+
WithCheckPointStore(store),
1800+
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
1801+
)
1802+
assert.NoError(t, err)
1803+
1804+
_, err = r.Invoke(ctx, "test", WithCheckPointID("cp1"))
1805+
assert.NotNil(t, err)
1806+
info, ok := ExtractInterruptInfo(err)
1807+
assert.True(t, ok)
1808+
assert.Contains(t, info.SubGraphs, "2")
1809+
subInfo := info.SubGraphs["2"]
1810+
assert.Equal(t, []string{"sub1"}, subInfo.RerunNodes)
1811+
1812+
assert.Equal(t, "test_main", receivedInput)
1813+
1814+
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1815+
assert.NoError(t, err)
1816+
assert.Equal(t, "test_main_sub_processed", result)
1817+
assert.Equal(t, "test_main", receivedInput)
1818+
assert.Equal(t, 2, callCount)
1819+
}

compose/graph.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,15 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
696696
for name, node := range g.nodes {
697697
node.beforeChildGraphCompile(name, key2SubGraphs)
698698

699+
if opt != nil && opt.checkpointConfig != nil && node.g != nil {
700+
if node.nodeInfo.compileOption == nil {
701+
node.nodeInfo.compileOption = &graphCompileOptions{}
702+
}
703+
if node.nodeInfo.compileOption.checkpointConfig == nil {
704+
node.nodeInfo.compileOption.checkpointConfig = opt.checkpointConfig
705+
}
706+
}
707+
699708
r, err := node.compileIfNeeded(ctx)
700709
if err != nil {
701710
return nil, err
@@ -837,6 +846,7 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
837846
outputPairs[START] = r.inputConvertStreamPair
838847
r.checkPointer = newCheckPointer(inputPairs, outputPairs, opt.checkPointStore, opt.serializer)
839848

849+
r.checkpointConfig = opt.checkpointConfig
840850
r.interruptBeforeNodes = opt.interruptBeforeNodes
841851
r.interruptAfterNodes = opt.interruptAfterNodes
842852
r.options = *opt

compose/graph_compile_options.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,23 @@ type graphCompileOptions struct {
2929
serializer Serializer
3030
interruptBeforeNodes []string
3131
interruptAfterNodes []string
32+
checkpointConfig *CheckpointConfig
3233

3334
eagerDisabled bool
3435

3536
mergeConfigs map[string]FanInMergeConfig
3637
}
3738

39+
// CheckpointConfig groups the options that control checkpoint behavior.
// Subgraphs inherit this configuration unless they explicitly override it.
type CheckpointConfig struct {
	// PersistRerunInput controls whether the original input of a rerun node
	// is stored in the checkpoint.
	//
	// When true, stream inputs are tee'd before they are consumed so that
	// they can be persisted and restored when resuming from an interrupt.
	// When false (the default), rerun nodes receive zero/empty values on
	// resume, which keeps the behavior backward compatible.
	PersistRerunInput bool
}
48+
3849
func newGraphCompileOptions(opts ...GraphCompileOption) *graphCompileOptions {
3950
option := &graphCompileOptions{}
4051

@@ -124,6 +135,16 @@ func WithFanInMergeConfig(confs map[string]FanInMergeConfig) GraphCompileOption
124135
}
125136
}
126137

138+
// WithCheckpointConfig sets the checkpoint configuration for the graph.
139+
// This configuration is inherited by subgraphs.
140+
// CheckpointConfig.PersistRerunInput enables persisting rerun node inputs in checkpoint,
141+
// allowing nodes to be resumed with their original inputs after an interrupt.
142+
func WithCheckpointConfig(config CheckpointConfig) GraphCompileOption {
143+
return func(o *graphCompileOptions) {
144+
o.checkpointConfig = &config
145+
}
146+
}
147+
127148
// InitGraphCompileCallbacks set global graph compile callbacks,
128149
// which ONLY will be added to top level graph compile options
129150
func InitGraphCompileCallbacks(cbs []GraphCompileCallback) {

0 commit comments

Comments
 (0)