@@ -339,39 +339,6 @@ def power_step(prior_result, x):
339339 compare_numba_and_py ([A ], result , test_input_vals )
340340
341341
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
    """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.

    A two-tap (``[-2, -1]``) recurrence is scanned for a symbolic number of
    steps, and the Numba result is compared against the Python result with
    ``scan_save_mem`` explicitly included in both compilation modes.
    """

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    init_x = pt.dvector("init_x")
    n_steps = pt.iscalar("n_steps")

    # Single recurrent output that reads its two previous values.
    taps_spec = {"initial": init_x, "taps": [-2, -1]}
    output, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[taps_spec],
        non_sequences=[],
        n_steps=n_steps,
    )

    # Both backends are compiled with the rewrite under test enabled.
    py_mode = Mode("py").including("scan_save_mem")
    numba_mode = get_mode("NUMBA").including("scan_save_mem")

    state_val = np.array([1.0, 2.0])
    test_input_vals = (state_val, n_steps_val)

    compare_numba_and_py(
        [init_x, n_steps],
        [output],
        test_input_vals,
        numba_mode=numba_mode,
        py_mode=py_mode,
    )
373-
374-
375342def test_grad_sitsot ():
376343 def get_sum_of_grad (inp ):
377344 scan_outputs , updates = scan (
@@ -482,3 +449,122 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
482449 np .testing .assert_array_almost_equal (numba_r , ref_r )
483450
484451 benchmark (numba_fn , * test .values ())
452+
453+
@pytest.mark.parametrize(
    "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)])
class TestScanSITSOTBuffer:
    """Check the size of the buffer given to a single-tap Scan after rewrites.

    Only arguments that the test methods actually accept are parametrized at
    class level; pytest rejects a parametrization whose argument names are not
    part of the test function signature.
    """

    def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
        """Compile a simple recurrence and assert the size of the Scan buffer.

        Parameters
        ----------
        n_steps
            Total buffer length including the initial state; the Scan itself
            runs ``n_steps - 1`` iterations.
        op_size
            Static length of the state vector.
        buffer_size
            Label selecting which part of the Scan output is kept, and thereby
            the expected leading dimension of the buffer.
        benchmark
            Optional benchmark callable; when given, the compiled Numba
            function is benchmarked after the assertions.

        Raises
        ------
        ValueError
            If ``buffer_size`` is not one of the known labels.
        """
        x0 = pt.vector(shape=(op_size,), dtype="float64")
        xs, _ = pytensor.scan(
            fn=lambda xtm1: (xtm1 + 1),
            outputs_info=[x0],
            n_steps=n_steps - 1,  # 1- makes it easier to align/misalign
        )
        if buffer_size == "unit":
            xs_kept = xs[-1]  # Only last state is used
            expected_buffer_size = 2
        elif buffer_size == "aligned":
            xs_kept = xs[-2:]  # The buffer will be aligned at the end of the steps
            expected_buffer_size = 2
        elif buffer_size == "misaligned":
            xs_kept = xs[-3:]  # The buffer will be misaligned at the end of the steps
            expected_buffer_size = 3
        elif buffer_size == "whole":
            xs_kept = xs  # What users think is the whole buffer
            expected_buffer_size = n_steps - 1
        elif buffer_size == "whole+init":
            xs_kept = xs.owner.inputs[0]  # Whole buffer actually used by Scan
            expected_buffer_size = n_steps
        else:
            # Fail loudly instead of hitting an UnboundLocalError further down.
            raise ValueError(f"Unknown buffer_size: {buffer_size!r}")

        x_test = np.zeros(x0.type.shape)
        numba_fn, _ = compare_numba_and_py(
            [x0],
            [xs_kept],
            test_inputs=[x_test],
            numba_mode="NUMBA",  # Default doesn't include optimizations
            eval_obj_mode=False,
        )
        # Exactly one Scan node is expected in the compiled graph.
        [scan_node] = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        # The second input of the Scan node is the buffer being checked.
        buffer = scan_node.inputs[1]
        assert buffer.type.shape[0] == expected_buffer_size

        if benchmark is not None:
            numba_fn.trust_input = True
            benchmark(numba_fn, x_test)

    def test_buffer(self, n_steps, op_size, buffer_size):
        """Run the buffer-size check without benchmarking."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None)

    def test_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark):
        """Run the buffer-size check and benchmark the compiled function."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark)
509+
510+
@pytest.mark.parametrize("constant_n_steps", [False, True])
@pytest.mark.parametrize("n_steps_val", [1, 1000])
class TestScanMITSOTBuffer:
    """Check the buffer of a two-tap Scan when only the last output is kept."""

    def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None):
        """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.

        Parameters
        ----------
        constant_n_steps
            When true, ``n_steps_val`` is passed to ``scan`` directly;
            otherwise a symbolic integer scalar is used and fed at call time.
        n_steps_val
            Number of Scan iterations.
        benchmark
            Optional benchmark callable; when given, the compiled Numba
            function is benchmarked after the assertions.
        """

        def f_pow2(x_tm2, x_tm1):
            return 2 * x_tm1 + x_tm2

        init_x = pt.vector("init_x", shape=(2,))
        n_steps = pt.iscalar("n_steps")
        init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)

        # Pick the graph inputs / test values matching the kind of n_steps.
        if constant_n_steps:
            scan_n_steps = n_steps_val
            graph_inputs = [init_x]
            test_vals = [init_x_val]
        else:
            scan_n_steps = n_steps
            graph_inputs = [init_x, n_steps]
            test_vals = [
                init_x_val,
                np.asarray(n_steps_val, dtype=n_steps.type.dtype),
            ]

        output, _ = scan(
            f_pow2,
            sequences=[],
            outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
            non_sequences=[],
            n_steps=scan_n_steps,
        )

        # Only the last state is requested from the Scan.
        last_state = output[-1]
        numba_fn, _ = compare_numba_and_py(
            graph_inputs,
            [last_state],
            test_vals,
            numba_mode="NUMBA",
            eval_obj_mode=False,
        )

        if constant_n_steps and n_steps_val == 1:
            # There's no Scan in the graph when nsteps=constant(1)
            return

        # Check the buffer size has been optimized
        scan_nodes = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        [scan_node] = scan_nodes
        [mitsot_buffer] = scan_node.op.outer_mitsot(scan_node.inputs)
        # n_steps is unused in the constant case, hence on_unused_input="ignore".
        mitsot_buffer_shape = mitsot_buffer.shape.eval(
            {init_x: init_x_val, n_steps: n_steps_val},
            accept_inplace=True,
            on_unused_input="ignore",
        )
        assert tuple(mitsot_buffer_shape) == (3,)

        if benchmark is not None:
            numba_fn.trust_input = True
            benchmark(numba_fn, *test_vals)

    def test_buffer(self, constant_n_steps, n_steps_val):
        """Run the buffer check without benchmarking."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None)

    def test_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
        """Run the buffer check and benchmark the compiled function."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
0 commit comments