@@ -605,7 +605,7 @@ def test_change_specify_shape_size_multivariate():
605
605
606
606
607
607
@pytest .mark .parametrize (
608
- "steps , shape, step_shape_offset, expected_steps , consistent" ,
608
+ "support_shape , shape, support_shape_offset, expected_support_shape , consistent" ,
609
609
[
610
610
(10 , None , 0 , 10 , True ),
611
611
(10 , None , 1 , 10 , True ),
@@ -621,44 +621,46 @@ def test_change_specify_shape_size_multivariate():
621
621
)
622
622
@pytest .mark .parametrize ("info_source" , ("shape" , "dims" , "observed" ))
623
623
def test_get_support_shape_1d (
624
- info_source , steps , shape , step_shape_offset , expected_steps , consistent
624
+ info_source , support_shape , shape , support_shape_offset , expected_support_shape , consistent
625
625
):
626
626
if info_source == "shape" :
627
- inferred_steps = get_support_shape_1d (
628
- support_shape = steps , shape = shape , support_shape_offset = step_shape_offset
627
+ inferred_support_shape = get_support_shape_1d (
628
+ support_shape = support_shape , shape = shape , support_shape_offset = support_shape_offset
629
629
)
630
630
631
631
elif info_source == "dims" :
632
632
if shape is None :
633
633
dims = None
634
634
coords = {}
635
635
else :
636
- dims = tuple (str (i ) for i , shape in enumerate (shape ))
636
+ dims = tuple (str (i ) for i , _ in enumerate (shape ))
637
637
coords = {str (i ): range (shape ) for i , shape in enumerate (shape )}
638
638
with Model (coords = coords ):
639
- inferred_steps = get_support_shape_1d (
640
- support_shape = steps , dims = dims , support_shape_offset = step_shape_offset
639
+ inferred_support_shape = get_support_shape_1d (
640
+ support_shape = support_shape , dims = dims , support_shape_offset = support_shape_offset
641
641
)
642
642
643
643
elif info_source == "observed" :
644
644
if shape is None :
645
645
observed = None
646
646
else :
647
647
observed = np .zeros (shape )
648
- inferred_steps = get_support_shape_1d (
649
- support_shape = steps , observed = observed , support_shape_offset = step_shape_offset
648
+ inferred_support_shape = get_support_shape_1d (
649
+ support_shape = support_shape ,
650
+ observed = observed ,
651
+ support_shape_offset = support_shape_offset ,
650
652
)
651
653
652
- if not isinstance (inferred_steps , TensorVariable ):
653
- assert inferred_steps == expected_steps
654
+ if not isinstance (inferred_support_shape , TensorVariable ):
655
+ assert inferred_support_shape == expected_support_shape
654
656
else :
655
657
if consistent :
656
- assert inferred_steps .eval () == expected_steps
658
+ assert inferred_support_shape .eval () == expected_support_shape
657
659
else :
658
660
# check that inferred steps is still correct by ignoring the assert
659
661
f = aesara .function (
660
- [], inferred_steps , mode = Mode ().including ("local_remove_all_assert" )
662
+ [], inferred_support_shape , mode = Mode ().including ("local_remove_all_assert" )
661
663
)
662
- assert f () == expected_steps
663
- with pytest .raises (AssertionError , match = "Steps do not match" ):
664
- inferred_steps .eval ()
664
+ assert f () == expected_support_shape
665
+ with pytest .raises (AssertionError , match = "support_shape does not match" ):
666
+ inferred_support_shape .eval ()
0 commit comments