@@ -20,12 +20,16 @@ import (
2020 "context"
2121 "errors"
2222 "io"
23+ "sync"
2324 "testing"
2425 "time"
2526
27+ "github.com/bytedance/sonic"
2628 "github.com/stretchr/testify/assert"
2729
30+ "github.com/cloudwego/eino/components/tool"
2831 "github.com/cloudwego/eino/internal/callbacks"
32+ "github.com/cloudwego/eino/internal/generic"
2933 "github.com/cloudwego/eino/internal/serialization"
3034 "github.com/cloudwego/eino/schema"
3135)
@@ -1397,7 +1401,7 @@ func TestCancelInterrupt(t *testing.T) {
13971401 assert .NoError (t , err )
13981402 assert .Equal (t , "input12" , result )
13991403
1400- // interrupt rerun nodes
1404+ // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
14011405 canceledCtx , cancel = WithGraphInterrupt (ctx )
14021406 go func () {
14031407 time .Sleep (500 * time .Millisecond )
@@ -1410,7 +1414,7 @@ func TestCancelInterrupt(t *testing.T) {
14101414 assert .Equal (t , []string {"1" }, info .RerunNodes )
14111415 result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
14121416 assert .NoError (t , err )
1413- assert .Equal (t , "12 " , result )
1417+ assert .Equal (t , "input12 " , result )
14141418
14151419 // dag
14161420 g = NewGraph [string , string ]()
@@ -1455,7 +1459,7 @@ func TestCancelInterrupt(t *testing.T) {
14551459 assert .NoError (t , err )
14561460 assert .Equal (t , "input12" , result )
14571461
1458- // interrupt rerun nodes
1462+ // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
14591463 canceledCtx , cancel = WithGraphInterrupt (ctx )
14601464 go func () {
14611465 time .Sleep (300 * time .Millisecond )
@@ -1468,7 +1472,7 @@ func TestCancelInterrupt(t *testing.T) {
14681472 assert .Equal (t , []string {"1" }, info .RerunNodes )
14691473 result , err = r .Invoke (ctx , "input" , WithCheckPointID ("3" ))
14701474 assert .NoError (t , err )
1471- assert .Equal (t , "12 " , result )
1475+ assert .Equal (t , "input12 " , result )
14721476
14731477 // dag multi canceled nodes
14741478 gg := NewGraph [string , map [string ]any ]()
@@ -1513,7 +1517,7 @@ func TestCancelInterrupt(t *testing.T) {
15131517 "3" : "input13" ,
15141518 }, result2 )
15151519
1516- // interrupt rerun nodes
1520+ // interrupt rerun nodes - with auto-enabled PersistRerunInput, input is preserved
15171521 canceledCtx , cancel = WithGraphInterrupt (ctx )
15181522 go func () {
15191523 time .Sleep (500 * time .Millisecond )
@@ -1527,8 +1531,8 @@ func TestCancelInterrupt(t *testing.T) {
15271531 result2 , err = rr .Invoke (ctx , "input" , WithCheckPointID ("2" ))
15281532 assert .NoError (t , err )
15291533 assert .Equal (t , map [string ]any {
1530- "2" : "2 " ,
1531- "3" : "3 " ,
1534+ "2" : "input12 " ,
1535+ "3" : "input13 " ,
15321536 }, result2 )
15331537}
15341538
@@ -1817,3 +1821,181 @@ func TestPersistRerunInputSubGraph(t *testing.T) {
18171821 assert .Equal (t , "test_main" , receivedInput )
18181822 assert .Equal (t , 2 , callCount )
18191823}
1824+
1825+ type longRunningToolInput struct {
1826+ Input string `json:"input"`
1827+ }
1828+
1829+ func TestToolsNodeWithExternalGraphInterrupt (t * testing.T ) {
1830+ store := newInMemoryStore ()
1831+ ctx := context .Background ()
1832+
1833+ var mu sync.Mutex
1834+ var callCount int
1835+
1836+ longRunningToolInfo := & schema.ToolInfo {
1837+ Name : "long_running_tool" ,
1838+ Desc : "A tool that takes a long time to run" ,
1839+ ParamsOneOf : schema .NewParamsOneOfByParams (map [string ]* schema.ParameterInfo {
1840+ "input" : {Type : "string" , Desc : "input" },
1841+ }),
1842+ }
1843+
1844+ longRunningTool := newCheckpointTestTool (longRunningToolInfo , func (ctx context.Context , in * longRunningToolInput ) (string , error ) {
1845+ mu .Lock ()
1846+ callCount ++
1847+ currentCount := callCount
1848+ mu .Unlock ()
1849+
1850+ if currentCount == 1 {
1851+ time .Sleep (2 * time .Second )
1852+ }
1853+ return "result_" + in .Input , nil
1854+ })
1855+
1856+ toolsNode , err := NewToolNode (ctx , & ToolsNodeConfig {
1857+ Tools : []tool.BaseTool {longRunningTool },
1858+ })
1859+ assert .NoError (t , err )
1860+
1861+ g := NewGraph [* schema.Message , []* schema.Message ]()
1862+ err = g .AddToolsNode ("tools" , toolsNode )
1863+ assert .NoError (t , err )
1864+ err = g .AddEdge (START , "tools" )
1865+ assert .NoError (t , err )
1866+ err = g .AddEdge ("tools" , END )
1867+ assert .NoError (t , err )
1868+
1869+ r , err := g .Compile (ctx ,
1870+ WithNodeTriggerMode (AllPredecessor ),
1871+ WithCheckPointStore (store ),
1872+ )
1873+ assert .NoError (t , err )
1874+
1875+ inputMsg := & schema.Message {
1876+ Role : schema .Assistant ,
1877+ ToolCalls : []schema.ToolCall {{
1878+ ID : "call_1" ,
1879+ Type : "function" ,
1880+ Function : schema.FunctionCall {
1881+ Name : "long_running_tool" ,
1882+ Arguments : `{"input": "test"}` ,
1883+ },
1884+ }},
1885+ }
1886+
1887+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1888+ go func () {
1889+ time .Sleep (100 * time .Millisecond )
1890+ cancel (WithGraphInterruptTimeout (0 ))
1891+ }()
1892+
1893+ _ , err = r .Invoke (canceledCtx , inputMsg , WithCheckPointID ("cp1" ))
1894+ assert .Error (t , err )
1895+ info , ok := ExtractInterruptInfo (err )
1896+ assert .True (t , ok , "Expected interrupt error, got: %v" , err )
1897+ if ok {
1898+ assert .Equal (t , []string {"tools" }, info .RerunNodes )
1899+ }
1900+
1901+ result , err := r .Invoke (ctx , & schema.Message {}, WithCheckPointID ("cp1" ))
1902+ assert .NoError (t , err )
1903+ assert .Len (t , result , 1 )
1904+ assert .Equal (t , `"result_test"` , result [0 ].Content )
1905+
1906+ mu .Lock ()
1907+ assert .Equal (t , 2 , callCount )
1908+ mu .Unlock ()
1909+ }
1910+
1911+ func TestExternalInterruptRespectsExplicitPersistRerunInputFalse (t * testing.T ) {
1912+ store := newInMemoryStore ()
1913+ ctx := context .Background ()
1914+
1915+ var mu sync.Mutex
1916+ var callCount int
1917+ var receivedInputOnResume string
1918+
1919+ g := NewGraph [string , string ]()
1920+ err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1921+ mu .Lock ()
1922+ callCount ++
1923+ currentCount := callCount
1924+ mu .Unlock ()
1925+
1926+ if currentCount == 1 {
1927+ time .Sleep (2 * time .Second )
1928+ }
1929+ if currentCount == 2 {
1930+ mu .Lock ()
1931+ receivedInputOnResume = input
1932+ mu .Unlock ()
1933+ }
1934+ return input + "_processed" , nil
1935+ }))
1936+ assert .NoError (t , err )
1937+
1938+ err = g .AddEdge (START , "1" )
1939+ assert .NoError (t , err )
1940+ err = g .AddEdge ("1" , END )
1941+ assert .NoError (t , err )
1942+
1943+ r , err := g .Compile (ctx ,
1944+ WithNodeTriggerMode (AllPredecessor ),
1945+ WithCheckPointStore (store ),
1946+ WithCheckpointConfig (CheckpointConfig {PersistRerunInput : false }),
1947+ )
1948+ assert .NoError (t , err )
1949+
1950+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1951+ go func () {
1952+ time .Sleep (100 * time .Millisecond )
1953+ cancel (WithGraphInterruptTimeout (0 ))
1954+ }()
1955+
1956+ _ , err = r .Invoke (canceledCtx , "test_input" , WithCheckPointID ("cp1" ))
1957+ assert .Error (t , err )
1958+ info , ok := ExtractInterruptInfo (err )
1959+ assert .True (t , ok , "Expected interrupt error, got: %v" , err )
1960+ if ok {
1961+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1962+ }
1963+
1964+ result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1965+ assert .NoError (t , err )
1966+ assert .Equal (t , "_processed" , result )
1967+
1968+ mu .Lock ()
1969+ assert .Equal (t , "" , receivedInputOnResume )
1970+ assert .Equal (t , 2 , callCount )
1971+ mu .Unlock ()
1972+ }
1973+
1974+ type checkpointTestTool [I , O any ] struct {
1975+ info * schema.ToolInfo
1976+ fn func (ctx context.Context , in I ) (O , error )
1977+ }
1978+
1979+ func newCheckpointTestTool [I , O any ](info * schema.ToolInfo , f func (ctx context.Context , in I ) (O , error )) tool.InvokableTool {
1980+ return & checkpointTestTool [I , O ]{
1981+ info : info ,
1982+ fn : f ,
1983+ }
1984+ }
1985+
1986+ func (f * checkpointTestTool [I , O ]) Info (ctx context.Context ) (* schema.ToolInfo , error ) {
1987+ return f .info , nil
1988+ }
1989+
1990+ func (f * checkpointTestTool [I , O ]) InvokableRun (ctx context.Context , argumentsInJSON string , _ ... tool.Option ) (string , error ) {
1991+ t := generic .NewInstance [I ]()
1992+ err := sonic .UnmarshalString (argumentsInJSON , t )
1993+ if err != nil {
1994+ return "" , err
1995+ }
1996+ o , err := f .fn (ctx , t )
1997+ if err != nil {
1998+ return "" , err
1999+ }
2000+ return sonic .MarshalString (o )
2001+ }
0 commit comments