Skip to content

Commit ef1ae4e

Browse files
refactor(compose): remove CheckpointConfig, auto-enable PersistRerunInput with WithGraphInterrupt
- Remove the CheckpointConfig struct and the WithCheckpointConfig option.
- Replace the checkpointConfig field with a persistRerunInput bool in taskManager.
- Auto-enable persistRerunInput when WithGraphInterrupt is used.
- Subgraphs inherit this behavior through context propagation.
- Update tests to use the WithGraphInterrupt pattern.

Change-Id: I15c3e10d815b5aa39768e6fdee1407025e5ea542
1 parent 50b2f1c commit ef1ae4e

File tree

5 files changed

+94
-130
lines changed

5 files changed

+94
-130
lines changed

compose/checkpoint_test.go

Lines changed: 86 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,16 +1539,21 @@ func TestCancelInterrupt(t *testing.T) {
15391539
func TestPersistRerunInputNonStream(t *testing.T) {
15401540
store := newInMemoryStore()
15411541

1542+
var mu sync.Mutex
15421543
var receivedInput string
15431544
var callCount int
15441545

15451546
g := NewGraph[string, string]()
15461547

15471548
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1549+
mu.Lock()
15481550
callCount++
1551+
currentCount := callCount
15491552
receivedInput = input
1550-
if callCount == 1 {
1551-
return "", Interrupt(ctx, "interrupt")
1553+
mu.Unlock()
1554+
1555+
if currentCount == 1 {
1556+
time.Sleep(2 * time.Second)
15521557
}
15531558
return input + "_processed", nil
15541559
}))
@@ -1563,35 +1568,49 @@ func TestPersistRerunInputNonStream(t *testing.T) {
15631568
r, err := g.Compile(ctx,
15641569
WithNodeTriggerMode(AllPredecessor),
15651570
WithCheckPointStore(store),
1566-
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
15671571
)
15681572
assert.NoError(t, err)
15691573

1570-
_, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1"))
1574+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1575+
go func() {
1576+
time.Sleep(100 * time.Millisecond)
1577+
cancel(WithGraphInterruptTimeout(0))
1578+
}()
1579+
1580+
_, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1"))
15711581
assert.NotNil(t, err)
15721582
info, ok := ExtractInterruptInfo(err)
15731583
assert.True(t, ok)
15741584
assert.Equal(t, []string{"1"}, info.RerunNodes)
15751585

1586+
mu.Lock()
15761587
assert.Equal(t, "test_input", receivedInput)
1588+
mu.Unlock()
15771589

15781590
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
15791591
assert.NoError(t, err)
15801592
assert.Equal(t, "test_input_processed", result)
1593+
1594+
mu.Lock()
15811595
assert.Equal(t, "test_input", receivedInput)
15821596
assert.Equal(t, 2, callCount)
1597+
mu.Unlock()
15831598
}
15841599

15851600
func TestPersistRerunInputStream(t *testing.T) {
15861601
store := newInMemoryStore()
15871602

1603+
var mu sync.Mutex
15881604
var receivedInput string
15891605
var callCount int
15901606

15911607
g := NewGraph[string, string]()
15921608

15931609
err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) {
1610+
mu.Lock()
15941611
callCount++
1612+
currentCount := callCount
1613+
mu.Unlock()
15951614

15961615
var sb string
15971616
for {
@@ -1604,10 +1623,13 @@ func TestPersistRerunInputStream(t *testing.T) {
16041623
}
16051624
sb += chunk
16061625
}
1626+
1627+
mu.Lock()
16071628
receivedInput = sb
1629+
mu.Unlock()
16081630

1609-
if callCount == 1 {
1610-
return nil, Interrupt(ctx, "interrupt")
1631+
if currentCount == 1 {
1632+
time.Sleep(2 * time.Second)
16111633
}
16121634

16131635
return schema.StreamReaderFromArray([]string{sb + "_processed"}), nil
@@ -1623,19 +1645,26 @@ func TestPersistRerunInputStream(t *testing.T) {
16231645
r, err := g.Compile(ctx,
16241646
WithNodeTriggerMode(AllPredecessor),
16251647
WithCheckPointStore(store),
1626-
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
16271648
)
16281649
assert.NoError(t, err)
16291650

16301651
inputStream := schema.StreamReaderFromArray([]string{"chunk1", "chunk2", "chunk3"})
16311652

1632-
_, err = r.Transform(ctx, inputStream, WithCheckPointID("cp1"))
1653+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1654+
go func() {
1655+
time.Sleep(100 * time.Millisecond)
1656+
cancel(WithGraphInterruptTimeout(0))
1657+
}()
1658+
1659+
_, err = r.Transform(canceledCtx, inputStream, WithCheckPointID("cp1"))
16331660
assert.NotNil(t, err)
16341661
info, ok := ExtractInterruptInfo(err)
16351662
assert.True(t, ok)
16361663
assert.Equal(t, []string{"1"}, info.RerunNodes)
16371664

1665+
mu.Lock()
16381666
assert.Equal(t, "chunk1chunk2chunk3", receivedInput)
1667+
mu.Unlock()
16391668

16401669
emptyInputStream := schema.StreamReaderFromArray([]string{})
16411670

@@ -1653,8 +1682,11 @@ func TestPersistRerunInputStream(t *testing.T) {
16531682
}
16541683

16551684
assert.Equal(t, "chunk1chunk2chunk3_processed", result)
1685+
1686+
mu.Lock()
16561687
assert.Equal(t, "chunk1chunk2chunk3", receivedInput)
16571688
assert.Equal(t, 2, callCount)
1689+
mu.Unlock()
16581690
}
16591691

16601692
type testPersistRerunInputState struct {
@@ -1664,6 +1696,7 @@ type testPersistRerunInputState struct {
16641696
func TestPersistRerunInputWithPreHandler(t *testing.T) {
16651697
store := newInMemoryStore()
16661698

1699+
var mu sync.Mutex
16671700
var receivedInput string
16681701
var callCount int
16691702

@@ -1674,10 +1707,14 @@ func TestPersistRerunInputWithPreHandler(t *testing.T) {
16741707
}))
16751708

16761709
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1710+
mu.Lock()
16771711
callCount++
1712+
currentCount := callCount
16781713
receivedInput = input
1679-
if callCount == 1 {
1680-
return "", Interrupt(ctx, "interrupt")
1714+
mu.Unlock()
1715+
1716+
if currentCount == 1 {
1717+
time.Sleep(2 * time.Second)
16811718
}
16821719
return input + "_processed", nil
16831720
}), WithStatePreHandler(func(ctx context.Context, in string, s *testPersistRerunInputState) (string, error) {
@@ -1694,25 +1731,35 @@ func TestPersistRerunInputWithPreHandler(t *testing.T) {
16941731
r, err := g.Compile(ctx,
16951732
WithNodeTriggerMode(AllPredecessor),
16961733
WithCheckPointStore(store),
1697-
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
16981734
)
16991735
assert.NoError(t, err)
17001736

1701-
_, err = r.Invoke(ctx, "test_input", WithCheckPointID("cp1"))
1737+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1738+
go func() {
1739+
time.Sleep(100 * time.Millisecond)
1740+
cancel(WithGraphInterruptTimeout(0))
1741+
}()
1742+
1743+
_, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1"))
17021744
assert.NotNil(t, err)
17031745
info, ok := ExtractInterruptInfo(err)
17041746
assert.True(t, ok)
17051747
if ok {
17061748
assert.Equal(t, []string{"1"}, info.RerunNodes)
17071749
}
17081750

1751+
mu.Lock()
17091752
assert.Equal(t, "prefix_test_input", receivedInput)
1753+
mu.Unlock()
17101754

17111755
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
17121756
assert.NoError(t, err)
17131757
assert.Equal(t, "prefix_test_input_processed", result)
1758+
1759+
mu.Lock()
17141760
assert.Equal(t, "prefix_test_input", receivedInput)
17151761
assert.Equal(t, 2, callCount)
1762+
mu.Unlock()
17161763
}
17171764

17181765
func TestPersistRerunInputBackwardCompatibility(t *testing.T) {
@@ -1765,15 +1812,20 @@ func TestPersistRerunInputBackwardCompatibility(t *testing.T) {
17651812
func TestPersistRerunInputSubGraph(t *testing.T) {
17661813
store := newInMemoryStore()
17671814

1815+
var mu sync.Mutex
17681816
var receivedInput string
17691817
var callCount int
17701818

17711819
subG := NewGraph[string, string]()
17721820
err := subG.AddLambdaNode("sub1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1821+
mu.Lock()
17731822
callCount++
1823+
currentCount := callCount
17741824
receivedInput = input
1775-
if callCount == 1 {
1776-
return "", Interrupt(ctx, "interrupt")
1825+
mu.Unlock()
1826+
1827+
if currentCount == 1 {
1828+
time.Sleep(2 * time.Second)
17771829
}
17781830
return input + "_sub_processed", nil
17791831
}))
@@ -1801,25 +1853,39 @@ func TestPersistRerunInputSubGraph(t *testing.T) {
18011853
r, err := g.Compile(ctx,
18021854
WithNodeTriggerMode(AllPredecessor),
18031855
WithCheckPointStore(store),
1804-
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: true}),
18051856
)
18061857
assert.NoError(t, err)
18071858

1808-
_, err = r.Invoke(ctx, "test", WithCheckPointID("cp1"))
1859+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1860+
go func() {
1861+
time.Sleep(100 * time.Millisecond)
1862+
cancel(WithGraphInterruptTimeout(0))
1863+
}()
1864+
1865+
_, err = r.Invoke(canceledCtx, "test", WithCheckPointID("cp1"))
18091866
assert.NotNil(t, err)
18101867
info, ok := ExtractInterruptInfo(err)
1811-
assert.True(t, ok)
1812-
assert.Contains(t, info.SubGraphs, "2")
1813-
subInfo := info.SubGraphs["2"]
1814-
assert.Equal(t, []string{"sub1"}, subInfo.RerunNodes)
1868+
assert.True(t, ok, "Expected interrupt error, got: %v", err)
1869+
if len(info.SubGraphs) > 0 {
1870+
assert.Contains(t, info.SubGraphs, "2")
1871+
subInfo := info.SubGraphs["2"]
1872+
assert.Equal(t, []string{"sub1"}, subInfo.RerunNodes)
1873+
} else {
1874+
assert.Equal(t, []string{"2"}, info.RerunNodes)
1875+
}
18151876

1877+
mu.Lock()
18161878
assert.Equal(t, "test_main", receivedInput)
1879+
mu.Unlock()
18171880

18181881
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
18191882
assert.NoError(t, err)
18201883
assert.Equal(t, "test_main_sub_processed", result)
1884+
1885+
mu.Lock()
18211886
assert.Equal(t, "test_main", receivedInput)
18221887
assert.Equal(t, 2, callCount)
1888+
mu.Unlock()
18231889
}
18241890

18251891
type longRunningToolInput struct {
@@ -1908,69 +1974,6 @@ func TestToolsNodeWithExternalGraphInterrupt(t *testing.T) {
19081974
mu.Unlock()
19091975
}
19101976

1911-
func TestExternalInterruptRespectsExplicitPersistRerunInputFalse(t *testing.T) {
1912-
store := newInMemoryStore()
1913-
ctx := context.Background()
1914-
1915-
var mu sync.Mutex
1916-
var callCount int
1917-
var receivedInputOnResume string
1918-
1919-
g := NewGraph[string, string]()
1920-
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1921-
mu.Lock()
1922-
callCount++
1923-
currentCount := callCount
1924-
mu.Unlock()
1925-
1926-
if currentCount == 1 {
1927-
time.Sleep(2 * time.Second)
1928-
}
1929-
if currentCount == 2 {
1930-
mu.Lock()
1931-
receivedInputOnResume = input
1932-
mu.Unlock()
1933-
}
1934-
return input + "_processed", nil
1935-
}))
1936-
assert.NoError(t, err)
1937-
1938-
err = g.AddEdge(START, "1")
1939-
assert.NoError(t, err)
1940-
err = g.AddEdge("1", END)
1941-
assert.NoError(t, err)
1942-
1943-
r, err := g.Compile(ctx,
1944-
WithNodeTriggerMode(AllPredecessor),
1945-
WithCheckPointStore(store),
1946-
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: false}),
1947-
)
1948-
assert.NoError(t, err)
1949-
1950-
canceledCtx, cancel := WithGraphInterrupt(ctx)
1951-
go func() {
1952-
time.Sleep(100 * time.Millisecond)
1953-
cancel(WithGraphInterruptTimeout(0))
1954-
}()
1955-
1956-
_, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1"))
1957-
assert.Error(t, err)
1958-
info, ok := ExtractInterruptInfo(err)
1959-
assert.True(t, ok, "Expected interrupt error, got: %v", err)
1960-
if ok {
1961-
assert.Equal(t, []string{"1"}, info.RerunNodes)
1962-
}
1963-
1964-
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1965-
assert.NoError(t, err)
1966-
assert.Equal(t, "_processed", result)
1967-
1968-
mu.Lock()
1969-
assert.Equal(t, "", receivedInputOnResume)
1970-
assert.Equal(t, 2, callCount)
1971-
mu.Unlock()
1972-
}
1973-
19741977
type checkpointTestTool[I, O any] struct {
19751978
info *schema.ToolInfo
19761979
fn func(ctx context.Context, in I) (O, error)

compose/graph.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -696,15 +696,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
696696
for name, node := range g.nodes {
697697
node.beforeChildGraphCompile(name, key2SubGraphs)
698698

699-
if opt != nil && opt.checkpointConfig != nil && node.g != nil {
700-
if node.nodeInfo.compileOption == nil {
701-
node.nodeInfo.compileOption = &graphCompileOptions{}
702-
}
703-
if node.nodeInfo.compileOption.checkpointConfig == nil {
704-
node.nodeInfo.compileOption.checkpointConfig = opt.checkpointConfig
705-
}
706-
}
707-
708699
r, err := node.compileIfNeeded(ctx)
709700
if err != nil {
710701
return nil, err
@@ -846,7 +837,6 @@ func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composa
846837
outputPairs[START] = r.inputConvertStreamPair
847838
r.checkPointer = newCheckPointer(inputPairs, outputPairs, opt.checkPointStore, opt.serializer)
848839

849-
r.checkpointConfig = opt.checkpointConfig
850840
r.interruptBeforeNodes = opt.interruptBeforeNodes
851841
r.interruptAfterNodes = opt.interruptAfterNodes
852842
r.options = *opt

compose/graph_compile_options.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,12 @@ type graphCompileOptions struct {
2929
serializer Serializer
3030
interruptBeforeNodes []string
3131
interruptAfterNodes []string
32-
checkpointConfig *CheckpointConfig
3332

3433
eagerDisabled bool
3534

3635
mergeConfigs map[string]FanInMergeConfig
3736
}
3837

39-
// CheckpointConfig contains configuration options for checkpoint behavior.
40-
// This configuration is inherited by subgraphs unless explicitly overridden.
41-
type CheckpointConfig struct {
42-
// PersistRerunInput enables persisting the original input for rerun nodes in checkpoint.
43-
// When enabled, stream inputs are tee'd before consumption so they can be persisted
44-
// and restored when resuming from an interrupt.
45-
// When disabled (default), rerun nodes receive zero/empty values on resume for backward compatibility.
46-
PersistRerunInput bool
47-
}
48-
4938
func newGraphCompileOptions(opts ...GraphCompileOption) *graphCompileOptions {
5039
option := &graphCompileOptions{}
5140

@@ -135,16 +124,6 @@ func WithFanInMergeConfig(confs map[string]FanInMergeConfig) GraphCompileOption
135124
}
136125
}
137126

138-
// WithCheckpointConfig sets the checkpoint configuration for the graph.
139-
// This configuration is inherited by subgraphs.
140-
// CheckpointConfig.PersistRerunInput enables persisting rerun node inputs in checkpoint,
141-
// allowing nodes to be resumed with their original inputs after an interrupt.
142-
func WithCheckpointConfig(config CheckpointConfig) GraphCompileOption {
143-
return func(o *graphCompileOptions) {
144-
o.checkpointConfig = &config
145-
}
146-
}
147-
148127
// InitGraphCompileCallbacks set global graph compile callbacks,
149128
// which ONLY will be added to top level graph compile options
150129
func InitGraphCompileCallbacks(cbs []GraphCompileCallback) {

0 commit comments

Comments
 (0)