@@ -1539,16 +1539,21 @@ func TestCancelInterrupt(t *testing.T) {
15391539func TestPersistRerunInputNonStream (t * testing.T ) {
15401540 store := newInMemoryStore ()
15411541
1542+ var mu sync.Mutex
15421543 var receivedInput string
15431544 var callCount int
15441545
15451546 g := NewGraph [string , string ]()
15461547
15471548 err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1549+ mu .Lock ()
15481550 callCount ++
1551+ currentCount := callCount
15491552 receivedInput = input
1550- if callCount == 1 {
1551- return "" , Interrupt (ctx , "interrupt" )
1553+ mu .Unlock ()
1554+
1555+ if currentCount == 1 {
1556+ time .Sleep (2 * time .Second )
15521557 }
15531558 return input + "_processed" , nil
15541559 }))
@@ -1563,35 +1568,49 @@ func TestPersistRerunInputNonStream(t *testing.T) {
15631568 r , err := g .Compile (ctx ,
15641569 WithNodeTriggerMode (AllPredecessor ),
15651570 WithCheckPointStore (store ),
1566- WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
15671571 )
15681572 assert .NoError (t , err )
15691573
1570- _ , err = r .Invoke (ctx , "test_input" , WithCheckPointID ("cp1" ))
1574+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1575+ go func () {
1576+ time .Sleep (100 * time .Millisecond )
1577+ cancel (WithGraphInterruptTimeout (0 ))
1578+ }()
1579+
1580+ _ , err = r .Invoke (canceledCtx , "test_input" , WithCheckPointID ("cp1" ))
15711581 assert .NotNil (t , err )
15721582 info , ok := ExtractInterruptInfo (err )
15731583 assert .True (t , ok )
15741584 assert .Equal (t , []string {"1" }, info .RerunNodes )
15751585
1586+ mu .Lock ()
15761587 assert .Equal (t , "test_input" , receivedInput )
1588+ mu .Unlock ()
15771589
15781590 result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
15791591 assert .NoError (t , err )
15801592 assert .Equal (t , "test_input_processed" , result )
1593+
1594+ mu .Lock ()
15811595 assert .Equal (t , "test_input" , receivedInput )
15821596 assert .Equal (t , 2 , callCount )
1597+ mu .Unlock ()
15831598}
15841599
15851600func TestPersistRerunInputStream (t * testing.T ) {
15861601 store := newInMemoryStore ()
15871602
1603+ var mu sync.Mutex
15881604 var receivedInput string
15891605 var callCount int
15901606
15911607 g := NewGraph [string , string ]()
15921608
15931609 err := g .AddLambdaNode ("1" , TransformableLambda (func (ctx context.Context , input * schema.StreamReader [string ]) (output * schema.StreamReader [string ], err error ) {
1610+ mu .Lock ()
15941611 callCount ++
1612+ currentCount := callCount
1613+ mu .Unlock ()
15951614
15961615 var sb string
15971616 for {
@@ -1604,10 +1623,13 @@ func TestPersistRerunInputStream(t *testing.T) {
16041623 }
16051624 sb += chunk
16061625 }
1626+
1627+ mu .Lock ()
16071628 receivedInput = sb
1629+ mu .Unlock ()
16081630
1609- if callCount == 1 {
1610- return nil , Interrupt ( ctx , "interrupt" )
1631+ if currentCount == 1 {
1632+ time . Sleep ( 2 * time . Second )
16111633 }
16121634
16131635 return schema .StreamReaderFromArray ([]string {sb + "_processed" }), nil
@@ -1623,19 +1645,26 @@ func TestPersistRerunInputStream(t *testing.T) {
16231645 r , err := g .Compile (ctx ,
16241646 WithNodeTriggerMode (AllPredecessor ),
16251647 WithCheckPointStore (store ),
1626- WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
16271648 )
16281649 assert .NoError (t , err )
16291650
16301651 inputStream := schema .StreamReaderFromArray ([]string {"chunk1" , "chunk2" , "chunk3" })
16311652
1632- _ , err = r .Transform (ctx , inputStream , WithCheckPointID ("cp1" ))
1653+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1654+ go func () {
1655+ time .Sleep (100 * time .Millisecond )
1656+ cancel (WithGraphInterruptTimeout (0 ))
1657+ }()
1658+
1659+ _ , err = r .Transform (canceledCtx , inputStream , WithCheckPointID ("cp1" ))
16331660 assert .NotNil (t , err )
16341661 info , ok := ExtractInterruptInfo (err )
16351662 assert .True (t , ok )
16361663 assert .Equal (t , []string {"1" }, info .RerunNodes )
16371664
1665+ mu .Lock ()
16381666 assert .Equal (t , "chunk1chunk2chunk3" , receivedInput )
1667+ mu .Unlock ()
16391668
16401669 emptyInputStream := schema .StreamReaderFromArray ([]string {})
16411670
@@ -1653,8 +1682,11 @@ func TestPersistRerunInputStream(t *testing.T) {
16531682 }
16541683
16551684 assert .Equal (t , "chunk1chunk2chunk3_processed" , result )
1685+
1686+ mu .Lock ()
16561687 assert .Equal (t , "chunk1chunk2chunk3" , receivedInput )
16571688 assert .Equal (t , 2 , callCount )
1689+ mu .Unlock ()
16581690}
16591691
16601692type testPersistRerunInputState struct {
@@ -1664,6 +1696,7 @@ type testPersistRerunInputState struct {
16641696func TestPersistRerunInputWithPreHandler (t * testing.T ) {
16651697 store := newInMemoryStore ()
16661698
1699+ var mu sync.Mutex
16671700 var receivedInput string
16681701 var callCount int
16691702
@@ -1674,10 +1707,14 @@ func TestPersistRerunInputWithPreHandler(t *testing.T) {
16741707 }))
16751708
16761709 err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1710+ mu .Lock ()
16771711 callCount ++
1712+ currentCount := callCount
16781713 receivedInput = input
1679- if callCount == 1 {
1680- return "" , Interrupt (ctx , "interrupt" )
1714+ mu .Unlock ()
1715+
1716+ if currentCount == 1 {
1717+ time .Sleep (2 * time .Second )
16811718 }
16821719 return input + "_processed" , nil
16831720 }), WithStatePreHandler (func (ctx context.Context , in string , s * testPersistRerunInputState ) (string , error ) {
@@ -1694,25 +1731,35 @@ func TestPersistRerunInputWithPreHandler(t *testing.T) {
16941731 r , err := g .Compile (ctx ,
16951732 WithNodeTriggerMode (AllPredecessor ),
16961733 WithCheckPointStore (store ),
1697- WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
16981734 )
16991735 assert .NoError (t , err )
17001736
1701- _ , err = r .Invoke (ctx , "test_input" , WithCheckPointID ("cp1" ))
1737+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1738+ go func () {
1739+ time .Sleep (100 * time .Millisecond )
1740+ cancel (WithGraphInterruptTimeout (0 ))
1741+ }()
1742+
1743+ _ , err = r .Invoke (canceledCtx , "test_input" , WithCheckPointID ("cp1" ))
17021744 assert .NotNil (t , err )
17031745 info , ok := ExtractInterruptInfo (err )
17041746 assert .True (t , ok )
17051747 if ok {
17061748 assert .Equal (t , []string {"1" }, info .RerunNodes )
17071749 }
17081750
1751+ mu .Lock ()
17091752 assert .Equal (t , "prefix_test_input" , receivedInput )
1753+ mu .Unlock ()
17101754
17111755 result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
17121756 assert .NoError (t , err )
17131757 assert .Equal (t , "prefix_test_input_processed" , result )
1758+
1759+ mu .Lock ()
17141760 assert .Equal (t , "prefix_test_input" , receivedInput )
17151761 assert .Equal (t , 2 , callCount )
1762+ mu .Unlock ()
17161763}
17171764
17181765func TestPersistRerunInputBackwardCompatibility (t * testing.T ) {
@@ -1765,15 +1812,20 @@ func TestPersistRerunInputBackwardCompatibility(t *testing.T) {
17651812func TestPersistRerunInputSubGraph (t * testing.T ) {
17661813 store := newInMemoryStore ()
17671814
1815+ var mu sync.Mutex
17681816 var receivedInput string
17691817 var callCount int
17701818
17711819 subG := NewGraph [string , string ]()
17721820 err := subG .AddLambdaNode ("sub1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1821+ mu .Lock ()
17731822 callCount ++
1823+ currentCount := callCount
17741824 receivedInput = input
1775- if callCount == 1 {
1776- return "" , Interrupt (ctx , "interrupt" )
1825+ mu .Unlock ()
1826+
1827+ if currentCount == 1 {
1828+ time .Sleep (2 * time .Second )
17771829 }
17781830 return input + "_sub_processed" , nil
17791831 }))
@@ -1801,25 +1853,39 @@ func TestPersistRerunInputSubGraph(t *testing.T) {
18011853 r , err := g .Compile (ctx ,
18021854 WithNodeTriggerMode (AllPredecessor ),
18031855 WithCheckPointStore (store ),
1804- WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
18051856 )
18061857 assert .NoError (t , err )
18071858
1808- _ , err = r .Invoke (ctx , "test" , WithCheckPointID ("cp1" ))
1859+ canceledCtx , cancel := WithGraphInterrupt (ctx )
1860+ go func () {
1861+ time .Sleep (100 * time .Millisecond )
1862+ cancel (WithGraphInterruptTimeout (0 ))
1863+ }()
1864+
1865+ _ , err = r .Invoke (canceledCtx , "test" , WithCheckPointID ("cp1" ))
18091866 assert .NotNil (t , err )
18101867 info , ok := ExtractInterruptInfo (err )
1811- assert .True (t , ok )
1812- assert .Contains (t , info .SubGraphs , "2" )
1813- subInfo := info .SubGraphs ["2" ]
1814- assert .Equal (t , []string {"sub1" }, subInfo .RerunNodes )
1868+ assert .True (t , ok , "Expected interrupt error, got: %v" , err )
1869+ if len (info .SubGraphs ) > 0 {
1870+ assert .Contains (t , info .SubGraphs , "2" )
1871+ subInfo := info .SubGraphs ["2" ]
1872+ assert .Equal (t , []string {"sub1" }, subInfo .RerunNodes )
1873+ } else {
1874+ assert .Equal (t , []string {"2" }, info .RerunNodes )
1875+ }
18151876
1877+ mu .Lock ()
18161878 assert .Equal (t , "test_main" , receivedInput )
1879+ mu .Unlock ()
18171880
18181881 result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
18191882 assert .NoError (t , err )
18201883 assert .Equal (t , "test_main_sub_processed" , result )
1884+
1885+ mu .Lock ()
18211886 assert .Equal (t , "test_main" , receivedInput )
18221887 assert .Equal (t , 2 , callCount )
1888+ mu .Unlock ()
18231889}
18241890
18251891type longRunningToolInput struct {
@@ -1908,69 +1974,6 @@ func TestToolsNodeWithExternalGraphInterrupt(t *testing.T) {
19081974 mu .Unlock ()
19091975}
19101976
1911- func TestExternalInterruptRespectsExplicitPersistRerunInputFalse (t * testing.T ) {
1912- store := newInMemoryStore ()
1913- ctx := context .Background ()
1914-
1915- var mu sync.Mutex
1916- var callCount int
1917- var receivedInputOnResume string
1918-
1919- g := NewGraph [string , string ]()
1920- err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1921- mu .Lock ()
1922- callCount ++
1923- currentCount := callCount
1924- mu .Unlock ()
1925-
1926- if currentCount == 1 {
1927- time .Sleep (2 * time .Second )
1928- }
1929- if currentCount == 2 {
1930- mu .Lock ()
1931- receivedInputOnResume = input
1932- mu .Unlock ()
1933- }
1934- return input + "_processed" , nil
1935- }))
1936- assert .NoError (t , err )
1937-
1938- err = g .AddEdge (START , "1" )
1939- assert .NoError (t , err )
1940- err = g .AddEdge ("1" , END )
1941- assert .NoError (t , err )
1942-
1943- r , err := g .Compile (ctx ,
1944- WithNodeTriggerMode (AllPredecessor ),
1945- WithCheckPointStore (store ),
1946- WithCheckpointConfig (CheckpointConfig {PersistRerunInput : false }),
1947- )
1948- assert .NoError (t , err )
1949-
1950- canceledCtx , cancel := WithGraphInterrupt (ctx )
1951- go func () {
1952- time .Sleep (100 * time .Millisecond )
1953- cancel (WithGraphInterruptTimeout (0 ))
1954- }()
1955-
1956- _ , err = r .Invoke (canceledCtx , "test_input" , WithCheckPointID ("cp1" ))
1957- assert .Error (t , err )
1958- info , ok := ExtractInterruptInfo (err )
1959- assert .True (t , ok , "Expected interrupt error, got: %v" , err )
1960- if ok {
1961- assert .Equal (t , []string {"1" }, info .RerunNodes )
1962- }
1963-
1964- result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1965- assert .NoError (t , err )
1966- assert .Equal (t , "_processed" , result )
1967-
1968- mu .Lock ()
1969- assert .Equal (t , "" , receivedInputOnResume )
1970- assert .Equal (t , 2 , callCount )
1971- mu .Unlock ()
1972- }
1973-
19741977type checkpointTestTool [I , O any ] struct {
19751978 info * schema.ToolInfo
19761979 fn func (ctx context.Context , in I ) (O , error )
0 commit comments