Skip to content

Commit 2fc2efe

Browse files
committed
Benchmark Scan buffer optimization in Numba
1 parent 9bbd244 commit 2fc2efe

File tree

1 file changed

+119
-33
lines changed

1 file changed

+119
-33
lines changed

tests/link/numba/test_scan.py

Lines changed: 119 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -339,39 +339,6 @@ def power_step(prior_result, x):
339339
compare_numba_and_py([A], result, test_input_vals)
340340

341341

342-
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
    """Compare Numba and Python results for a two-tap scan with `scan_save_mem` enabled."""

    def recurrence(two_back, one_back):
        # Next state is built from the two previous states.
        return 2 * one_back + two_back

    init_x = pt.dvector("init_x")
    n_steps = pt.iscalar("n_steps")

    # A single recurrent output that reads its own values at lags -2 and -1.
    recurrent_spec = {"initial": init_x, "taps": [-2, -1]}
    output, _ = scan(
        recurrence,
        sequences=[],
        outputs_info=[recurrent_spec],
        non_sequences=[],
        n_steps=n_steps,
    )

    # Both backends are run with the `scan_save_mem` rewrite explicitly included.
    mode_for_numba = get_mode("NUMBA").including("scan_save_mem")
    mode_for_py = Mode("py").including("scan_save_mem")

    initial_state = np.array([1.0, 2.0])

    compare_numba_and_py(
        [init_x, n_steps],
        [output],
        (initial_state, n_steps_val),
        numba_mode=mode_for_numba,
        py_mode=mode_for_py,
    )
373-
374-
375342
def test_grad_sitsot():
376343
def get_sum_of_grad(inp):
377344
scan_outputs, updates = scan(
@@ -482,3 +449,122 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
482449
np.testing.assert_array_almost_equal(numba_r, ref_r)
483450

484451
benchmark(numba_fn, *test.values())
452+
453+
454+
@pytest.mark.parametrize(
    "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)])
class TestScanSITSOTBuffer:
    """Check, and optionally benchmark, the state buffer size of a single-tap Scan under Numba.

    Only ``buffer_size``, ``n_steps`` and ``op_size`` are parametrized here:
    every parametrized name must be accepted by each test method, otherwise
    pytest refuses to collect the class.
    """

    def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
        """Build a ``x[t] = x[t-1] + 1`` scan, compile it and check its buffer length.

        Parameters
        ----------
        n_steps
            Total buffer length including the initial state; the scan itself
            runs ``n_steps - 1`` iterations.
        op_size
            Length of the state vector.
        buffer_size
            Which part of the scan output is requested; one of ``"unit"``,
            ``"aligned"``, ``"misaligned"``, ``"whole"`` or ``"whole+init"``.
        benchmark
            Optional benchmark callable; when given, the compiled function is
            benchmarked after the buffer check.

        Raises
        ------
        ValueError
            If ``buffer_size`` is not one of the supported labels.
        """
        x0 = pt.vector(shape=(op_size,), dtype="float64")
        xs, _ = pytensor.scan(
            fn=lambda xtm1: (xtm1 + 1),
            outputs_info=[x0],
            # Running one step less than `n_steps` makes it easier to
            # align/misalign the buffer in the cases below.
            n_steps=n_steps - 1,
        )
        if buffer_size == "unit":
            xs_kept = xs[-1]  # Only last state is used
            expected_buffer_size = 2
        elif buffer_size == "aligned":
            xs_kept = xs[-2:]  # The buffer will be aligned at the end of the steps
            expected_buffer_size = 2
        elif buffer_size == "misaligned":
            xs_kept = xs[-3:]  # The buffer will be misaligned at the end of the steps
            expected_buffer_size = 3
        elif buffer_size == "whole":
            xs_kept = xs  # What users think is the whole buffer
            expected_buffer_size = n_steps - 1
        elif buffer_size == "whole+init":
            xs_kept = xs.owner.inputs[0]  # Whole buffer actually used by Scan
            expected_buffer_size = n_steps
        else:
            # Fail loudly instead of hitting an UnboundLocalError further down.
            raise ValueError(f"Unknown buffer_size: {buffer_size!r}")

        x_test = np.zeros(x0.type.shape)
        numba_fn, _ = compare_numba_and_py(
            [x0],
            [xs_kept],
            test_inputs=[x_test],
            numba_mode="NUMBA",  # Default doesn't include optimizations
            eval_obj_mode=False,
        )
        # Exactly one Scan node is expected in the compiled graph.
        [scan_node] = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        # The second input of the Scan node is taken to be the state buffer.
        buffer = scan_node.inputs[1]
        assert buffer.type.shape[0] == expected_buffer_size

        if benchmark is not None:
            # NOTE(review): presumably skips input validation so the benchmark
            # measures the compiled function itself - confirm.
            numba_fn.trust_input = True
            benchmark(numba_fn, x_test)

    def test_buffer(self, n_steps, op_size, buffer_size):
        """Run the buffer size check without benchmarking."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None)

    def test_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark):
        """Run the buffer size check and benchmark the compiled function."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark)
509+
510+
511+
@pytest.mark.parametrize("constant_n_steps", [False, True])
@pytest.mark.parametrize("n_steps_val", [1, 1000])
class TestScanMITSOTBuffer:
    """Check, and optionally benchmark, the buffer of a two-tap Scan compiled with Numba."""

    def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None):
        """Compile a scan with taps ``[-2, -1]`` and check the size of its buffer.

        ``constant_n_steps`` selects whether the number of steps is baked into
        the graph as ``n_steps_val`` or passed in as a symbolic input.
        ``benchmark``, when given, is called on the compiled function.
        """

        def f_pow2(x_tm2, x_tm1):
            return 2 * x_tm1 + x_tm2

        init_x = pt.vector("init_x", shape=(2,))
        n_steps = pt.iscalar("n_steps")
        init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)

        # Graph inputs and test values depend on whether the step count is symbolic.
        if constant_n_steps:
            steps_arg = n_steps_val
            graph_inputs = [init_x]
            test_vals = [init_x_val]
        else:
            steps_arg = n_steps
            graph_inputs = [init_x, n_steps]
            test_vals = [init_x_val, np.asarray(n_steps_val, dtype=n_steps.type.dtype)]

        output, _ = scan(
            f_pow2,
            sequences=[],
            outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
            non_sequences=[],
            n_steps=steps_arg,
        )

        # Only the last state is requested.
        numba_fn, _ = compare_numba_and_py(
            graph_inputs,
            [output[-1]],
            test_vals,
            numba_mode="NUMBA",
            eval_obj_mode=False,
        )

        if constant_n_steps and n_steps_val == 1:
            # There's no Scan in the graph when nsteps=constant(1)
            return

        # Check that the buffer size has been optimized
        scan_nodes = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        [scan_node] = scan_nodes
        [mitsot_buffer] = scan_node.op.outer_mitsot(scan_node.inputs)
        # `n_steps` may be unused when the step count is constant, hence "ignore".
        eval_inputs = {init_x: init_x_val, n_steps: n_steps_val}
        mitsot_buffer_shape = mitsot_buffer.shape.eval(
            eval_inputs,
            accept_inplace=True,
            on_unused_input="ignore",
        )
        # The same length is expected for every parametrized step count.
        assert tuple(mitsot_buffer_shape) == (3,)

        if benchmark is None:
            return
        numba_fn.trust_input = True
        benchmark(numba_fn, *test_vals)

    def test_buffer(self, constant_n_steps, n_steps_val):
        """Run the buffer check without benchmarking."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None)

    def test_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
        """Run the buffer check and benchmark the compiled function."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)

0 commit comments

Comments
 (0)