Skip to content

Commit 3332064

Browse files
feat: auto-enable PersistRerunInput when WithGraphInterrupt is used
- Modified initTaskManager to automatically enable PersistRerunInput when WithGraphInterrupt is used and no explicit CheckpointConfig is set - This makes external interrupts work correctly by preserving node inputs - User's explicit CheckpointConfig settings are still respected - Added TestToolsNodeWithExternalGraphInterrupt test - Added TestExternalInterruptRespectsExplicitPersistRerunInputFalse test - Updated TestCancelInterrupt expectations to reflect new behavior Change-Id: I63d0c1812a5b20934eb6b7c62f53f55299a0c9d5
1 parent 3b91fdd commit 3332064

File tree

2 files changed

+195
-8
lines changed

2 files changed

+195
-8
lines changed

compose/checkpoint_test.go

Lines changed: 189 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@ import (
2020
"context"
2121
"errors"
2222
"io"
23+
"sync"
2324
"testing"
2425
"time"
2526

27+
"github.com/bytedance/sonic"
2628
"github.com/stretchr/testify/assert"
2729

30+
"github.com/cloudwego/eino/components/tool"
2831
"github.com/cloudwego/eino/internal/callbacks"
32+
"github.com/cloudwego/eino/internal/generic"
2933
"github.com/cloudwego/eino/internal/serialization"
3034
"github.com/cloudwego/eino/schema"
3135
)
@@ -1397,7 +1401,7 @@ func TestCancelInterrupt(t *testing.T) {
13971401
assert.NoError(t, err)
13981402
assert.Equal(t, "input12", result)
13991403

1400-
// interrupt rerun nodes
1404+
// interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
14011405
canceledCtx, cancel = WithGraphInterrupt(ctx)
14021406
go func() {
14031407
time.Sleep(500 * time.Millisecond)
@@ -1410,7 +1414,7 @@ func TestCancelInterrupt(t *testing.T) {
14101414
assert.Equal(t, []string{"1"}, info.RerunNodes)
14111415
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
14121416
assert.NoError(t, err)
1413-
assert.Equal(t, "12", result)
1417+
assert.Equal(t, "input12", result)
14141418

14151419
// dag
14161420
g = NewGraph[string, string]()
@@ -1455,7 +1459,7 @@ func TestCancelInterrupt(t *testing.T) {
14551459
assert.NoError(t, err)
14561460
assert.Equal(t, "input12", result)
14571461

1458-
// interrupt rerun nodes
1462+
// interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
14591463
canceledCtx, cancel = WithGraphInterrupt(ctx)
14601464
go func() {
14611465
time.Sleep(300 * time.Millisecond)
@@ -1468,7 +1472,7 @@ func TestCancelInterrupt(t *testing.T) {
14681472
assert.Equal(t, []string{"1"}, info.RerunNodes)
14691473
result, err = r.Invoke(ctx, "input", WithCheckPointID("3"))
14701474
assert.NoError(t, err)
1471-
assert.Equal(t, "12", result)
1475+
assert.Equal(t, "input12", result)
14721476

14731477
// dag multi canceled nodes
14741478
gg := NewGraph[string, map[string]any]()
@@ -1513,7 +1517,7 @@ func TestCancelInterrupt(t *testing.T) {
15131517
"3": "input13",
15141518
}, result2)
15151519

1516-
// interrupt rerun nodes
1520+
// interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
15171521
canceledCtx, cancel = WithGraphInterrupt(ctx)
15181522
go func() {
15191523
time.Sleep(500 * time.Millisecond)
@@ -1527,8 +1531,8 @@ func TestCancelInterrupt(t *testing.T) {
15271531
result2, err = rr.Invoke(ctx, "input", WithCheckPointID("2"))
15281532
assert.NoError(t, err)
15291533
assert.Equal(t, map[string]any{
1530-
"2": "2",
1531-
"3": "3",
1534+
"2": "input12",
1535+
"3": "input13",
15321536
}, result2)
15331537
}
15341538

@@ -1817,3 +1821,181 @@ func TestPersistRerunInputSubGraph(t *testing.T) {
18171821
assert.Equal(t, "test_main", receivedInput)
18181822
assert.Equal(t, 2, callCount)
18191823
}
1824+
1825+
type longRunningToolInput struct {
1826+
Input string `json:"input"`
1827+
}
1828+
1829+
func TestToolsNodeWithExternalGraphInterrupt(t *testing.T) {
1830+
store := newInMemoryStore()
1831+
ctx := context.Background()
1832+
1833+
var mu sync.Mutex
1834+
var callCount int
1835+
1836+
longRunningToolInfo := &schema.ToolInfo{
1837+
Name: "long_running_tool",
1838+
Desc: "A tool that takes a long time to run",
1839+
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
1840+
"input": {Type: "string", Desc: "input"},
1841+
}),
1842+
}
1843+
1844+
longRunningTool := newCheckpointTestTool(longRunningToolInfo, func(ctx context.Context, in *longRunningToolInput) (string, error) {
1845+
mu.Lock()
1846+
callCount++
1847+
currentCount := callCount
1848+
mu.Unlock()
1849+
1850+
if currentCount == 1 {
1851+
time.Sleep(2 * time.Second)
1852+
}
1853+
return "result_" + in.Input, nil
1854+
})
1855+
1856+
toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{
1857+
Tools: []tool.BaseTool{longRunningTool},
1858+
})
1859+
assert.NoError(t, err)
1860+
1861+
g := NewGraph[*schema.Message, []*schema.Message]()
1862+
err = g.AddToolsNode("tools", toolsNode)
1863+
assert.NoError(t, err)
1864+
err = g.AddEdge(START, "tools")
1865+
assert.NoError(t, err)
1866+
err = g.AddEdge("tools", END)
1867+
assert.NoError(t, err)
1868+
1869+
r, err := g.Compile(ctx,
1870+
WithNodeTriggerMode(AllPredecessor),
1871+
WithCheckPointStore(store),
1872+
)
1873+
assert.NoError(t, err)
1874+
1875+
inputMsg := &schema.Message{
1876+
Role: schema.Assistant,
1877+
ToolCalls: []schema.ToolCall{{
1878+
ID: "call_1",
1879+
Type: "function",
1880+
Function: schema.FunctionCall{
1881+
Name: "long_running_tool",
1882+
Arguments: `{"input": "test"}`,
1883+
},
1884+
}},
1885+
}
1886+
1887+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1888+
go func() {
1889+
time.Sleep(100 * time.Millisecond)
1890+
cancel(WithGraphInterruptTimeout(0))
1891+
}()
1892+
1893+
_, err = r.Invoke(canceledCtx, inputMsg, WithCheckPointID("cp1"))
1894+
assert.Error(t, err)
1895+
info, ok := ExtractInterruptInfo(err)
1896+
assert.True(t, ok, "Expected interrupt error, got: %v", err)
1897+
if ok {
1898+
assert.Equal(t, []string{"tools"}, info.RerunNodes)
1899+
}
1900+
1901+
result, err := r.Invoke(ctx, &schema.Message{}, WithCheckPointID("cp1"))
1902+
assert.NoError(t, err)
1903+
assert.Len(t, result, 1)
1904+
assert.Equal(t, `"result_test"`, result[0].Content)
1905+
1906+
mu.Lock()
1907+
assert.Equal(t, 2, callCount)
1908+
mu.Unlock()
1909+
}
1910+
1911+
func TestExternalInterruptRespectsExplicitPersistRerunInputFalse(t *testing.T) {
1912+
store := newInMemoryStore()
1913+
ctx := context.Background()
1914+
1915+
var mu sync.Mutex
1916+
var callCount int
1917+
var receivedInputOnResume string
1918+
1919+
g := NewGraph[string, string]()
1920+
err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
1921+
mu.Lock()
1922+
callCount++
1923+
currentCount := callCount
1924+
mu.Unlock()
1925+
1926+
if currentCount == 1 {
1927+
time.Sleep(2 * time.Second)
1928+
}
1929+
if currentCount == 2 {
1930+
mu.Lock()
1931+
receivedInputOnResume = input
1932+
mu.Unlock()
1933+
}
1934+
return input + "_processed", nil
1935+
}))
1936+
assert.NoError(t, err)
1937+
1938+
err = g.AddEdge(START, "1")
1939+
assert.NoError(t, err)
1940+
err = g.AddEdge("1", END)
1941+
assert.NoError(t, err)
1942+
1943+
r, err := g.Compile(ctx,
1944+
WithNodeTriggerMode(AllPredecessor),
1945+
WithCheckPointStore(store),
1946+
WithCheckpointConfig(CheckpointConfig{PersistRerunInput: false}),
1947+
)
1948+
assert.NoError(t, err)
1949+
1950+
canceledCtx, cancel := WithGraphInterrupt(ctx)
1951+
go func() {
1952+
time.Sleep(100 * time.Millisecond)
1953+
cancel(WithGraphInterruptTimeout(0))
1954+
}()
1955+
1956+
_, err = r.Invoke(canceledCtx, "test_input", WithCheckPointID("cp1"))
1957+
assert.Error(t, err)
1958+
info, ok := ExtractInterruptInfo(err)
1959+
assert.True(t, ok, "Expected interrupt error, got: %v", err)
1960+
if ok {
1961+
assert.Equal(t, []string{"1"}, info.RerunNodes)
1962+
}
1963+
1964+
result, err := r.Invoke(ctx, "", WithCheckPointID("cp1"))
1965+
assert.NoError(t, err)
1966+
assert.Equal(t, "_processed", result)
1967+
1968+
mu.Lock()
1969+
assert.Equal(t, "", receivedInputOnResume)
1970+
assert.Equal(t, 2, callCount)
1971+
mu.Unlock()
1972+
}
1973+
1974+
type checkpointTestTool[I, O any] struct {
1975+
info *schema.ToolInfo
1976+
fn func(ctx context.Context, in I) (O, error)
1977+
}
1978+
1979+
func newCheckpointTestTool[I, O any](info *schema.ToolInfo, f func(ctx context.Context, in I) (O, error)) tool.InvokableTool {
1980+
return &checkpointTestTool[I, O]{
1981+
info: info,
1982+
fn: f,
1983+
}
1984+
}
1985+
1986+
func (f *checkpointTestTool[I, O]) Info(ctx context.Context) (*schema.ToolInfo, error) {
1987+
return f.info, nil
1988+
}
1989+
1990+
func (f *checkpointTestTool[I, O]) InvokableRun(ctx context.Context, argumentsInJSON string, _ ...tool.Option) (string, error) {
1991+
t := generic.NewInstance[I]()
1992+
err := sonic.UnmarshalString(argumentsInJSON, t)
1993+
if err != nil {
1994+
return "", err
1995+
}
1996+
o, err := f.fn(ctx, t)
1997+
if err != nil {
1998+
return "", err
1999+
}
2000+
return sonic.MarshalString(o)
2001+
}

compose/graph_run.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,13 +912,18 @@ func (r *runner) calculateBranch(ctx context.Context, curNodeKey string, startCh
912912
}
913913

914914
func (r *runner) initTaskManager(runWrapper runnableCallWrapper, cancelVal *graphCancelChanVal, opts ...Option) *taskManager {
915+
checkpointConfig := r.checkpointConfig
916+
if cancelVal != nil && checkpointConfig == nil {
917+
checkpointConfig = &CheckpointConfig{PersistRerunInput: true}
918+
}
919+
915920
tm := &taskManager{
916921
runWrapper: runWrapper,
917922
opts: opts,
918923
needAll: !r.eager,
919924
done: internal.NewUnboundedChan[*task](),
920925
runningTasks: make(map[string]*task),
921-
checkpointConfig: r.checkpointConfig,
926+
checkpointConfig: checkpointConfig,
922927
}
923928
if cancelVal != nil {
924929
tm.cancelCh = cancelVal.ch

0 commit comments

Comments
 (0)