@@ -1531,3 +1531,289 @@ func TestCancelInterrupt(t *testing.T) {
15311531 "3" : "3" ,
15321532 }, result2 )
15331533}
1534+
1535+ func TestPersistRerunInputNonStream (t * testing.T ) {
1536+ store := newInMemoryStore ()
1537+
1538+ var receivedInput string
1539+ var callCount int
1540+
1541+ g := NewGraph [string , string ]()
1542+
1543+ err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1544+ callCount ++
1545+ receivedInput = input
1546+ if callCount == 1 {
1547+ return "" , Interrupt (ctx , "interrupt" )
1548+ }
1549+ return input + "_processed" , nil
1550+ }))
1551+ assert .NoError (t , err )
1552+
1553+ err = g .AddEdge (START , "1" )
1554+ assert .NoError (t , err )
1555+ err = g .AddEdge ("1" , END )
1556+ assert .NoError (t , err )
1557+
1558+ ctx := context .Background ()
1559+ r , err := g .Compile (ctx ,
1560+ WithNodeTriggerMode (AllPredecessor ),
1561+ WithCheckPointStore (store ),
1562+ WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
1563+ )
1564+ assert .NoError (t , err )
1565+
1566+ _ , err = r .Invoke (ctx , "test_input" , WithCheckPointID ("cp1" ))
1567+ assert .NotNil (t , err )
1568+ info , ok := ExtractInterruptInfo (err )
1569+ assert .True (t , ok )
1570+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1571+
1572+ assert .Equal (t , "test_input" , receivedInput )
1573+
1574+ result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1575+ assert .NoError (t , err )
1576+ assert .Equal (t , "test_input_processed" , result )
1577+ assert .Equal (t , "test_input" , receivedInput )
1578+ assert .Equal (t , 2 , callCount )
1579+ }
1580+
1581+ func TestPersistRerunInputStream (t * testing.T ) {
1582+ store := newInMemoryStore ()
1583+
1584+ var receivedInput string
1585+ var callCount int
1586+
1587+ g := NewGraph [string , string ]()
1588+
1589+ err := g .AddLambdaNode ("1" , TransformableLambda (func (ctx context.Context , input * schema.StreamReader [string ]) (output * schema.StreamReader [string ], err error ) {
1590+ callCount ++
1591+
1592+ var sb string
1593+ for {
1594+ chunk , err := input .Recv ()
1595+ if err == io .EOF {
1596+ break
1597+ }
1598+ if err != nil {
1599+ return nil , err
1600+ }
1601+ sb += chunk
1602+ }
1603+ receivedInput = sb
1604+
1605+ if callCount == 1 {
1606+ return nil , Interrupt (ctx , "interrupt" )
1607+ }
1608+
1609+ return schema .StreamReaderFromArray ([]string {sb + "_processed" }), nil
1610+ }))
1611+ assert .NoError (t , err )
1612+
1613+ err = g .AddEdge (START , "1" )
1614+ assert .NoError (t , err )
1615+ err = g .AddEdge ("1" , END )
1616+ assert .NoError (t , err )
1617+
1618+ ctx := context .Background ()
1619+ r , err := g .Compile (ctx ,
1620+ WithNodeTriggerMode (AllPredecessor ),
1621+ WithCheckPointStore (store ),
1622+ WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
1623+ )
1624+ assert .NoError (t , err )
1625+
1626+ inputStream := schema .StreamReaderFromArray ([]string {"chunk1" , "chunk2" , "chunk3" })
1627+
1628+ _ , err = r .Transform (ctx , inputStream , WithCheckPointID ("cp1" ))
1629+ assert .NotNil (t , err )
1630+ info , ok := ExtractInterruptInfo (err )
1631+ assert .True (t , ok )
1632+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1633+
1634+ assert .Equal (t , "chunk1chunk2chunk3" , receivedInput )
1635+
1636+ emptyInputStream := schema .StreamReaderFromArray ([]string {})
1637+
1638+ resultStream , err := r .Transform (ctx , emptyInputStream , WithCheckPointID ("cp1" ))
1639+ assert .NoError (t , err )
1640+
1641+ var result string
1642+ for {
1643+ chunk , err := resultStream .Recv ()
1644+ if err == io .EOF {
1645+ break
1646+ }
1647+ assert .NoError (t , err )
1648+ result += chunk
1649+ }
1650+
1651+ assert .Equal (t , "chunk1chunk2chunk3_processed" , result )
1652+ assert .Equal (t , "chunk1chunk2chunk3" , receivedInput )
1653+ assert .Equal (t , 2 , callCount )
1654+ }
1655+
1656+ type testPersistRerunInputState struct {
1657+ Prefix string
1658+ }
1659+
1660+ func TestPersistRerunInputWithPreHandler (t * testing.T ) {
1661+ store := newInMemoryStore ()
1662+
1663+ var receivedInput string
1664+ var callCount int
1665+
1666+ schema .Register [testPersistRerunInputState ]()
1667+
1668+ g := NewGraph [string , string ](WithGenLocalState (func (ctx context.Context ) * testPersistRerunInputState {
1669+ return & testPersistRerunInputState {Prefix : "prefix_" }
1670+ }))
1671+
1672+ err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1673+ callCount ++
1674+ receivedInput = input
1675+ if callCount == 1 {
1676+ return "" , Interrupt (ctx , "interrupt" )
1677+ }
1678+ return input + "_processed" , nil
1679+ }), WithStatePreHandler (func (ctx context.Context , in string , s * testPersistRerunInputState ) (string , error ) {
1680+ return s .Prefix + in , nil
1681+ }))
1682+ assert .NoError (t , err )
1683+
1684+ err = g .AddEdge (START , "1" )
1685+ assert .NoError (t , err )
1686+ err = g .AddEdge ("1" , END )
1687+ assert .NoError (t , err )
1688+
1689+ ctx := context .Background ()
1690+ r , err := g .Compile (ctx ,
1691+ WithNodeTriggerMode (AllPredecessor ),
1692+ WithCheckPointStore (store ),
1693+ WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
1694+ )
1695+ assert .NoError (t , err )
1696+
1697+ _ , err = r .Invoke (ctx , "test_input" , WithCheckPointID ("cp1" ))
1698+ assert .NotNil (t , err )
1699+ info , ok := ExtractInterruptInfo (err )
1700+ assert .True (t , ok )
1701+ if ok {
1702+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1703+ }
1704+
1705+ assert .Equal (t , "prefix_test_input" , receivedInput )
1706+
1707+ result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1708+ assert .NoError (t , err )
1709+ assert .Equal (t , "prefix_test_input_processed" , result )
1710+ assert .Equal (t , "prefix_test_input" , receivedInput )
1711+ assert .Equal (t , 2 , callCount )
1712+ }
1713+
1714+ func TestPersistRerunInputBackwardCompatibility (t * testing.T ) {
1715+ store := newInMemoryStore ()
1716+
1717+ var receivedInput string
1718+ var callCount int
1719+
1720+ g := NewGraph [string , string ]()
1721+
1722+ err := g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1723+ callCount ++
1724+ receivedInput = input
1725+ if len (input ) > 0 {
1726+ return "" , StatefulInterrupt (ctx , "interrupt" , input )
1727+ }
1728+
1729+ _ , _ , restoredInput := GetInterruptState [string ](ctx )
1730+ return restoredInput + "_processed" , nil
1731+ }))
1732+ assert .NoError (t , err )
1733+
1734+ err = g .AddEdge (START , "1" )
1735+ assert .NoError (t , err )
1736+ err = g .AddEdge ("1" , END )
1737+ assert .NoError (t , err )
1738+
1739+ ctx := context .Background ()
1740+ r , err := g .Compile (ctx ,
1741+ WithNodeTriggerMode (AllPredecessor ),
1742+ WithCheckPointStore (store ),
1743+ )
1744+ assert .NoError (t , err )
1745+
1746+ _ , err = r .Invoke (ctx , "test_input" , WithCheckPointID ("cp1" ))
1747+ assert .NotNil (t , err )
1748+ info , ok := ExtractInterruptInfo (err )
1749+ assert .True (t , ok )
1750+ assert .Equal (t , []string {"1" }, info .RerunNodes )
1751+
1752+ assert .Equal (t , "test_input" , receivedInput )
1753+
1754+ result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1755+ assert .NoError (t , err )
1756+ assert .Equal (t , "test_input_processed" , result )
1757+ assert .Equal (t , "" , receivedInput )
1758+ assert .Equal (t , 2 , callCount )
1759+ }
1760+
1761+ func TestPersistRerunInputSubGraph (t * testing.T ) {
1762+ store := newInMemoryStore ()
1763+
1764+ var receivedInput string
1765+ var callCount int
1766+
1767+ subG := NewGraph [string , string ]()
1768+ err := subG .AddLambdaNode ("sub1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1769+ callCount ++
1770+ receivedInput = input
1771+ if callCount == 1 {
1772+ return "" , Interrupt (ctx , "interrupt" )
1773+ }
1774+ return input + "_sub_processed" , nil
1775+ }))
1776+ assert .NoError (t , err )
1777+ err = subG .AddEdge (START , "sub1" )
1778+ assert .NoError (t , err )
1779+ err = subG .AddEdge ("sub1" , END )
1780+ assert .NoError (t , err )
1781+
1782+ g := NewGraph [string , string ]()
1783+ err = g .AddLambdaNode ("1" , InvokableLambda (func (ctx context.Context , input string ) (output string , err error ) {
1784+ return input + "_main" , nil
1785+ }))
1786+ assert .NoError (t , err )
1787+ err = g .AddGraphNode ("2" , subG )
1788+ assert .NoError (t , err )
1789+ err = g .AddEdge (START , "1" )
1790+ assert .NoError (t , err )
1791+ err = g .AddEdge ("1" , "2" )
1792+ assert .NoError (t , err )
1793+ err = g .AddEdge ("2" , END )
1794+ assert .NoError (t , err )
1795+
1796+ ctx := context .Background ()
1797+ r , err := g .Compile (ctx ,
1798+ WithNodeTriggerMode (AllPredecessor ),
1799+ WithCheckPointStore (store ),
1800+ WithCheckpointConfig (CheckpointConfig {PersistRerunInput : true }),
1801+ )
1802+ assert .NoError (t , err )
1803+
1804+ _ , err = r .Invoke (ctx , "test" , WithCheckPointID ("cp1" ))
1805+ assert .NotNil (t , err )
1806+ info , ok := ExtractInterruptInfo (err )
1807+ assert .True (t , ok )
1808+ assert .Contains (t , info .SubGraphs , "2" )
1809+ subInfo := info .SubGraphs ["2" ]
1810+ assert .Equal (t , []string {"sub1" }, subInfo .RerunNodes )
1811+
1812+ assert .Equal (t , "test_main" , receivedInput )
1813+
1814+ result , err := r .Invoke (ctx , "" , WithCheckPointID ("cp1" ))
1815+ assert .NoError (t , err )
1816+ assert .Equal (t , "test_main_sub_processed" , result )
1817+ assert .Equal (t , "test_main" , receivedInput )
1818+ assert .Equal (t , 2 , callCount )
1819+ }
0 commit comments