Skip to content

Commit 57a3462

Browse files
committed
Do more agressive scan memory saves in JIT backends
1 parent c4772cf commit 57a3462

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

pytensor/compile/mode.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
454454
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
455455
)
456456

457+
NUMBA = Mode(
458+
NumbaLinker(),
459+
RewriteDatabaseQuery(
460+
include=["fast_run", "numba"],
461+
exclude=[
462+
"cxx_only",
463+
"BlasOpt",
464+
"local_careduce_fusion",
465+
"scan_save_mem_prealloc",
466+
],
467+
),
468+
)
469+
457470
JAX = Mode(
458471
JAXLinker(),
459472
RewriteDatabaseQuery(
@@ -463,6 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
463476
"BlasOpt",
464477
"fusion",
465478
"inplace",
479+
"scan_save_mem_prealloc",
466480
],
467481
),
468482
)
@@ -475,17 +489,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
475489
"BlasOpt",
476490
"fusion",
477491
"inplace",
478-
"local_uint_constant_indices",
492+
"local_uint_constant_indices" "scan_save_mem_prealloc",
479493
],
480494
),
481495
)
482-
NUMBA = Mode(
483-
NumbaLinker(),
484-
RewriteDatabaseQuery(
485-
include=["fast_run", "numba"],
486-
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
487-
),
488-
)
489496

490497

491498
predefined_modes = {

pytensor/scan/rewriting.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,8 +1183,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
11831183
return subtensor_merge_replacements
11841184

11851185

1186-
@node_rewriter([Scan])
1187-
def scan_save_mem(fgraph, node):
1186+
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
11881187
r"""Graph optimizer that reduces scan memory consumption.
11891188
11901189
This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1215,10 +1214,16 @@ def scan_save_mem(fgraph, node):
12151214
12161215
The scan perform implementation takes the output sizes into consideration,
12171216
saving the newest results over the oldest ones whenever the buffer is filled.
1218-
"""
1219-
if not isinstance(node.op, Scan):
1220-
return False
12211217
1218+
Paramaters
1219+
----------
1220+
backend_supports_output_pre_allocation: bool
1221+
When the backend supports output pre-allocation Scan must keep buffers
1222+
with a length of required_states + 1, because the inner function will
1223+
attempt to write the inner function outputs directly into the provided
1224+
position in the outer circular buffer. This would invalidate results,
1225+
if the input is still needed for some other output computation.
1226+
"""
12221227
if hasattr(fgraph, "shape_feature"):
12231228
shape_of = fgraph.shape_feature.shape_of
12241229
else:
@@ -1487,7 +1492,10 @@ def scan_save_mem(fgraph, node):
14871492
# for mitsots and sitsots (because mitmots are not
14881493
# currently supported by the mechanism) and only if
14891494
# the pre-allocation mechanism is activated.
1490-
prealloc_outs = config.scan__allow_output_prealloc
1495+
prealloc_outs = (
1496+
backend_supports_output_pre_allocation
1497+
and config.scan__allow_output_prealloc
1498+
)
14911499

14921500
first_mitsot_idx = op_info.n_mit_mot
14931501
last_sitsot_idx = (
@@ -1496,6 +1504,8 @@ def scan_save_mem(fgraph, node):
14961504
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
14971505

14981506
if prealloc_outs and preallocable_output:
1507+
# TODO: If there's only one output or other outputs do not depend
1508+
# on the same input, we could reduce the buffer size to the minimum
14991509
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
15001510
else:
15011511
pval = select_max(nw_steps - start + init_l[i], init_l[i])
@@ -1781,6 +1791,20 @@ def scan_save_mem(fgraph, node):
17811791
return False
17821792

17831793

1794+
@node_rewriter([Scan])
1795+
def scan_save_mem_prealloc(fgraph, node):
1796+
return scan_save_mem_rewrite(
1797+
fgraph, node, backend_supports_output_pre_allocation=True
1798+
)
1799+
1800+
1801+
@node_rewriter([Scan])
1802+
def scan_save_mem_no_prealloc(fgraph, node):
1803+
return scan_save_mem_rewrite(
1804+
fgraph, node, backend_supports_output_pre_allocation=False
1805+
)
1806+
1807+
17841808
class ScanMerge(GraphRewriter):
17851809
r"""Graph optimizer that merges different scan ops.
17861810
@@ -2508,12 +2532,21 @@ def scan_push_out_dot1(fgraph, node):
25082532
optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6)
25092533
# ScanSaveMem should execute only once per node.
25102534
optdb.register(
2511-
"scan_save_mem",
2512-
in2out(scan_save_mem, ignore_newtrees=True),
2535+
"scan_save_mem_prealloc",
2536+
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
25132537
"fast_run",
25142538
"scan",
25152539
position=1.61,
25162540
)
2541+
optdb.register(
2542+
"scan_save_mem_no_prealloc",
2543+
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
2544+
"numba",
2545+
"jax",
2546+
"pytorch",
2547+
"scan",
2548+
position=1.61,
2549+
)
25172550
optdb.register(
25182551
"scan_make_inplace",
25192552
ScanInplaceOptimizer(),

tests/link/numba/test_scan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
467467
)
468468
if buffer_size == "unit":
469469
xs_kept = xs[-1] # Only last state is used
470-
expected_buffer_size = 2
470+
expected_buffer_size = 1
471471
elif buffer_size == "aligned":
472472
xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps
473473
expected_buffer_size = 2
@@ -557,7 +557,7 @@ def f_pow2(x_tm2, x_tm1):
557557
accept_inplace=True,
558558
on_unused_input="ignore",
559559
)
560-
assert tuple(mitsot_buffer_shape) == (3,)
560+
assert tuple(mitsot_buffer_shape) == (2,)
561561

562562
if benchmark is not None:
563563
numba_fn.trust_input = True

0 commit comments

Comments
 (0)