@@ -451,6 +451,80 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
451451 benchmark (numba_fn , * test .values ())
452452
453453
@pytest.mark.parametrize("n_steps_constant", (True, False))
def test_inplace_taps(n_steps_constant):
    """Check which inner-graph ops of a numba-compiled Scan work inplace.

    A Scan with two mit-sot states (``z`` with taps -3/-1, ``y`` with taps
    -1/-2), one sit-sot state (``x``) and one nit-sot output is compiled in
    NUMBA mode. The test then inspects the ``destroy_map`` of the ops that
    produce the inner outputs:

    * with a constant ``n_steps`` each recurrent output must destroy the
      inner input holding its oldest tap;
    * with a symbolic ``n_steps`` no inplace is expected (see the issue
      linked in the body).
    """

    def update(ztm3, ztm1, xtm1, ytm1, ytm2, a):
        # Expression order is kept as-is: the assertions below look up
        # positions in ``owner.inputs``.
        z = ztm1 + 1 + ztm3 + a
        x = xtm1 + 1
        y = ytm1 + 1 + ytm2 + a
        return z, x, z + x + y, y

    def destroy_map_of(var):
        return var.owner.op.destroy_map

    def inplace_on(out, inp):
        # destroy_map stating that output 0 overwrites the input ``inp``
        return {0: [out.owner.inputs.index(inp)]}

    if n_steps_constant:
        n_steps = 10
    else:
        n_steps = scalar("n_steps", dtype=int)
    a = scalar("a")
    x0 = scalar("x0")
    y0 = vector("y0", shape=(2,))
    z0 = vector("z0", shape=(3,))

    states = [
        dict(initial=z0, taps=[-3, -1]),
        dict(initial=x0, taps=[-1]),
        None,
        dict(initial=y0, taps=[-1, -2]),
    ]
    [zs, xs, ws, ys], _ = scan(
        fn=update,
        outputs_info=states,
        non_sequences=[a],
        n_steps=n_steps,
    )

    graph_inputs = [a, x0, y0, z0]
    test_values = [np.pi, np.e, [1, np.euler_gamma], [0, 1, 2]]
    if not n_steps_constant:
        # A symbolic n_steps is an extra leading function input
        graph_inputs.insert(0, n_steps)
        test_values.insert(0, 10)

    numba_fn, _ = compare_numba_and_py(
        graph_inputs,
        [zs[-1], xs[-1], ws[-1], ys[-1]],
        test_values,
        numba_mode="NUMBA",
        eval_obj_mode=False,
    )

    # Exactly one Scan node is expected in the compiled graph
    scan_ops = [
        node.op
        for node in numba_fn.maker.fgraph.toposort()
        if isinstance(node.op, Scan)
    ]
    [scan_op] = scan_ops

    # Scan reorders inputs internally, so we need to check its ordering
    inner_inps = scan_op.fgraph.inputs
    mit_sot_inps = scan_op.inner_mitsot(inner_inps)
    tap_slices = scan_op.info.mit_sot_in_slices
    # Implicitly assume that the first mit-sot is the one that includes tap -3
    # and the second the one that includes tap -2 (two inner inputs each).
    # This is not a required behavior and the test can change if we need to change Scan.
    first_mit_sot_inps = mit_sot_inps[:2]
    second_mit_sot_inps = mit_sot_inps[2:]
    oldest_mit_sot_inps = [
        first_mit_sot_inps[tap_slices[0].index(-3)],
        second_mit_sot_inps[tap_slices[1].index(-2)],
    ]
    [sit_sot_inp] = scan_op.inner_sitsot(inner_inps)

    inner_outs = scan_op.fgraph.outputs
    mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
    [sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
    [nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)

    if n_steps_constant:
        assert destroy_map_of(mit_sot_outs[0]) == inplace_on(
            mit_sot_outs[0], oldest_mit_sot_inps[0]
        )
        assert destroy_map_of(mit_sot_outs[1]) == inplace_on(
            mit_sot_outs[1], oldest_mit_sot_inps[1]
        )
        assert destroy_map_of(sit_sot_out) == inplace_on(sit_sot_out, sit_sot_inp)
    else:
        # This is not a feature, but a current limitation
        # https://github.com/pymc-devs/pytensor/issues/1283
        assert destroy_map_of(mit_sot_outs[0]) == {}
        assert destroy_map_of(mit_sot_outs[1]) == {}
        assert destroy_map_of(sit_sot_out) == {}
    # NOTE(review): indentation was lost in the captured source; this check is
    # assumed to apply to both parametrizations - confirm against the repository.
    assert destroy_map_of(nit_sot_out) == {}
526+
527+
454528@pytest .mark .parametrize (
455529 "buffer_size" , ("unit" , "aligned" , "misaligned" , "whole" , "whole+init" )
456530)
0 commit comments