Skip to content

Commit 9d419ef

Browse files
committed
Fix test_get_support_shape_1d
1 parent ca655bc commit 9d419ef

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

pymc/distributions/shape_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,8 @@ def get_support_shape_1d(
755755
"""Helper function for cases when you just care about one dimension."""
756756
if support_shape is not None:
757757
support_shape_tuple = (support_shape,)
758+
else:
759+
support_shape_tuple = None
758760

759761
support_shape_tuple = get_support_shape(
760762
support_shape_tuple,

pymc/tests/distributions/test_shape_utils.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def test_change_specify_shape_size_multivariate():
605605

606606

607607
@pytest.mark.parametrize(
608-
"steps, shape, step_shape_offset, expected_steps, consistent",
608+
"support_shape, shape, support_shape_offset, expected_support_shape, consistent",
609609
[
610610
(10, None, 0, 10, True),
611611
(10, None, 1, 10, True),
@@ -621,44 +621,46 @@ def test_change_specify_shape_size_multivariate():
621621
)
622622
@pytest.mark.parametrize("info_source", ("shape", "dims", "observed"))
623623
def test_get_support_shape_1d(
624-
info_source, steps, shape, step_shape_offset, expected_steps, consistent
624+
info_source, support_shape, shape, support_shape_offset, expected_support_shape, consistent
625625
):
626626
if info_source == "shape":
627-
inferred_steps = get_support_shape_1d(
628-
support_shape=steps, shape=shape, support_shape_offset=step_shape_offset
627+
inferred_support_shape = get_support_shape_1d(
628+
support_shape=support_shape, shape=shape, support_shape_offset=support_shape_offset
629629
)
630630

631631
elif info_source == "dims":
632632
if shape is None:
633633
dims = None
634634
coords = {}
635635
else:
636-
dims = tuple(str(i) for i, shape in enumerate(shape))
636+
dims = tuple(str(i) for i, _ in enumerate(shape))
637637
coords = {str(i): range(shape) for i, shape in enumerate(shape)}
638638
with Model(coords=coords):
639-
inferred_steps = get_support_shape_1d(
640-
support_shape=steps, dims=dims, support_shape_offset=step_shape_offset
639+
inferred_support_shape = get_support_shape_1d(
640+
support_shape=support_shape, dims=dims, support_shape_offset=support_shape_offset
641641
)
642642

643643
elif info_source == "observed":
644644
if shape is None:
645645
observed = None
646646
else:
647647
observed = np.zeros(shape)
648-
inferred_steps = get_support_shape_1d(
649-
support_shape=steps, observed=observed, support_shape_offset=step_shape_offset
648+
inferred_support_shape = get_support_shape_1d(
649+
support_shape=support_shape,
650+
observed=observed,
651+
support_shape_offset=support_shape_offset,
650652
)
651653

652-
if not isinstance(inferred_steps, TensorVariable):
653-
assert inferred_steps == expected_steps
654+
if not isinstance(inferred_support_shape, TensorVariable):
655+
assert inferred_support_shape == expected_support_shape
654656
else:
655657
if consistent:
656-
assert inferred_steps.eval() == expected_steps
658+
assert inferred_support_shape.eval() == expected_support_shape
657659
else:
658660
# check that inferred steps is still correct by ignoring the assert
659661
f = aesara.function(
660-
[], inferred_steps, mode=Mode().including("local_remove_all_assert")
662+
[], inferred_support_shape, mode=Mode().including("local_remove_all_assert")
661663
)
662-
assert f() == expected_steps
663-
with pytest.raises(AssertionError, match="Steps do not match"):
664-
inferred_steps.eval()
664+
assert f() == expected_support_shape
665+
with pytest.raises(AssertionError, match="support_shape does not match"):
666+
inferred_support_shape.eval()

0 commit comments

Comments
 (0)