@@ -1183,8 +1183,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
11831183 return subtensor_merge_replacements
11841184
11851185
1186- @node_rewriter ([Scan ])
1187- def scan_save_mem (fgraph , node ):
1186+ def scan_save_mem_rewrite (fgraph , node , backend_supports_output_pre_allocation : bool ):
11881187 r"""Graph optimizer that reduces scan memory consumption.
11891188
11901189 This optimization attempts to determine if a `Scan` node, during its execution,
@@ -1215,10 +1214,16 @@ def scan_save_mem(fgraph, node):
12151214
12161215 The scan perform implementation takes the output sizes into consideration,
12171216 saving the newest results over the oldest ones whenever the buffer is filled.
1218- """
1219- if not isinstance (node .op , Scan ):
1220- return False
12211217
1218+ Parameters
1219+ ----------
1220+ backend_supports_output_pre_allocation: bool
1221+ When the backend supports output pre-allocation, Scan must keep buffers
1222+ with a length of required_states + 1, because the inner function will
1223+ attempt to write its outputs directly into the provided position in the
1224+ outer circular buffer. This would invalidate results if the input is
1225+ still needed for some other output computation.
1226+ """
12221227 if hasattr (fgraph , "shape_feature" ):
12231228 shape_of = fgraph .shape_feature .shape_of
12241229 else :
@@ -1487,7 +1492,10 @@ def scan_save_mem(fgraph, node):
14871492 # for mitsots and sitsots (because mitmots are not
14881493 # currently supported by the mechanism) and only if
14891494 # the pre-allocation mechanism is activated.
1490- prealloc_outs = config .scan__allow_output_prealloc
1495+ prealloc_outs = (
1496+ backend_supports_output_pre_allocation
1497+ and config .scan__allow_output_prealloc
1498+ )
14911499
14921500 first_mitsot_idx = op_info .n_mit_mot
14931501 last_sitsot_idx = (
@@ -1496,6 +1504,8 @@ def scan_save_mem(fgraph, node):
14961504 preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
14971505
14981506 if prealloc_outs and preallocable_output :
1507+ # TODO: If there's only one output or other outputs do not depend
1508+ # on the same input, we could reduce the buffer size to the minimum
14991509 pval = select_max (nw_steps - start + init_l [i ], init_l [i ] + 1 )
15001510 else :
15011511 pval = select_max (nw_steps - start + init_l [i ], init_l [i ])
@@ -1781,6 +1791,20 @@ def scan_save_mem(fgraph, node):
17811791 return False
17821792
17831793
@node_rewriter([Scan])
def scan_save_mem_prealloc(fgraph, node):
    """Apply the scan memory-saving rewrite for pre-allocating backends.

    Delegates to `scan_save_mem_rewrite` with
    ``backend_supports_output_pre_allocation=True`` and returns whatever
    that function returns.
    """
    # The flag is passed by keyword so the call site stays self-describing.
    outcome = scan_save_mem_rewrite(
        fgraph,
        node,
        backend_supports_output_pre_allocation=True,
    )
    return outcome
1799+
1800+
@node_rewriter([Scan])
def scan_save_mem_no_prealloc(fgraph, node):
    """Apply the scan memory-saving rewrite for non-pre-allocating backends.

    Delegates to `scan_save_mem_rewrite` with
    ``backend_supports_output_pre_allocation=False`` and returns whatever
    that function returns.
    """
    # The flag is passed by keyword so the call site stays self-describing.
    outcome = scan_save_mem_rewrite(
        fgraph,
        node,
        backend_supports_output_pre_allocation=False,
    )
    return outcome
1806+
1807+
17841808class ScanMerge (GraphRewriter ):
17851809 r"""Graph optimizer that merges different scan ops.
17861810
@@ -2508,12 +2532,21 @@ def scan_push_out_dot1(fgraph, node):
25082532optdb .register ("scan_eqopt2" , scan_eqopt2 , "fast_run" , "scan" , position = 1.6 )
25092533# ScanSaveMem should execute only once per node.
25102534optdb .register (
2511- "scan_save_mem " ,
2512- in2out (scan_save_mem , ignore_newtrees = True ),
2535+ "scan_save_mem_prealloc " ,
2536+ in2out (scan_save_mem_prealloc , ignore_newtrees = True ),
25132537 "fast_run" ,
25142538 "scan" ,
25152539 position = 1.61 ,
25162540)
2541+ optdb .register (
2542+ "scan_save_mem_no_prealloc" ,
2543+ in2out (scan_save_mem_no_prealloc , ignore_newtrees = True ),
2544+ "numba" ,
2545+ "jax" ,
2546+ "pytorch" ,
2547+ "scan" ,
2548+ position = 1.61 ,
2549+ )
25172550optdb .register (
25182551 "scan_make_inplace" ,
25192552 ScanInplaceOptimizer (),
0 commit comments