Skip to content

Commit c4772cf

Browse files
committed
Do not try to save initial values buffer size in Scan
This will always require a roll at the end, for a minimal gain
1 parent 2fc2efe commit c4772cf

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

pytensor/scan/rewriting.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,9 +1430,18 @@ def scan_save_mem(fgraph, node):
14301430
store_steps[i] = 0
14311431
break
14321432

1433-
if isinstance(this_slice[0], slice) and this_slice[0].start is None:
1434-
store_steps[i] = 0
1435-
break
1433+
if isinstance(this_slice[0], slice):
1434+
start = this_slice[0].start
1435+
if isinstance(start, Constant):
1436+
start = start.data
1437+
# Don't do anything if the subtensor is starting from the beginning of the buffer
1438+
# Or just skipping the initial values (default output returned to the user).
1439+
# Trimming the initial values would require a roll to align the buffer once scan is done
1440+
# As it always starts writing at position [0+max(taps)], and ends up at position [:max(taps)]
1441+
# It's cheaper to just keep the initial values in the buffer and slice them away (default output)
1442+
if start in (0, None, init_l[i]):
1443+
store_steps[i] = 0
1444+
break
14361445

14371446
# Special case for recurrent outputs where only the last result
14381447
# is requested. This is needed for this rewrite to apply to

tests/link/numba/test_scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
476476
expected_buffer_size = 3
477477
elif buffer_size == "whole":
478478
xs_kept = xs # What users think is the whole buffer
479-
expected_buffer_size = n_steps - 1
479+
expected_buffer_size = n_steps
480480
elif buffer_size == "whole+init":
481481
xs_kept = xs.owner.inputs[0] # Whole buffer actually used by Scan
482482
expected_buffer_size = n_steps

0 commit comments

Comments
 (0)