Skip to content

Commit 85da56c

Browse files
committed
Add test_get_support_shape
1 parent 9d419ef commit 85da56c

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

pymc/tests/distributions/test_shape_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
convert_shape,
3838
convert_size,
3939
get_broadcastable_dist_samples,
40+
get_support_shape,
4041
get_support_shape_1d,
4142
rv_size_is_none,
4243
shapes_broadcasting,
@@ -664,3 +665,79 @@ def test_get_support_shape_1d(
664665
assert f() == expected_support_shape
665666
with pytest.raises(AssertionError, match="support_shape does not match"):
666667
inferred_support_shape.eval()
668+
669+
670+
@pytest.mark.parametrize(
671+
"support_shape, shape, support_shape_offset, expected_support_shape, ndim_supp, consistent",
672+
[
673+
((10, 5), None, (0,), (10, 5), 1, True),
674+
((10, 5), None, (1, 1), (10, 5), 1, True),
675+
(None, (10, 5), (0,), 5, 1, True),
676+
(None, (10, 5), (1,), 4, 1, True),
677+
(None, (10, 5, 2), (0,), 2, 1, True),
678+
(None, None, None, None, 1, True),
679+
((10, 5), (10, 5), None, (10, 5), 2, True),
680+
((10, 5), (11, 10, 5), None, (10, 5), 2, True),
681+
(None, (11, 10, 5), (0, 1, 0), (11, 9, 5), 3, True),
682+
((10, 5), (10, 5, 5), (0,), (5,), 1, False),
683+
((10, 5), (10, 5), (1, 1), (9, 4), 2, False),
684+
],
685+
)
686+
@pytest.mark.parametrize("info_source", ("shape", "dims", "observed"))
687+
def test_get_support_shape(
688+
info_source,
689+
support_shape,
690+
shape,
691+
support_shape_offset,
692+
expected_support_shape,
693+
ndim_supp,
694+
consistent,
695+
):
696+
if info_source == "shape":
697+
inferred_support_shape = get_support_shape(
698+
support_shape=support_shape,
699+
shape=shape,
700+
support_shape_offset=support_shape_offset,
701+
ndim_supp=ndim_supp,
702+
)
703+
704+
elif info_source == "dims":
705+
if shape is None:
706+
dims = None
707+
coords = {}
708+
else:
709+
dims = tuple(str(i) for i, _ in enumerate(shape))
710+
coords = {str(i): range(shape) for i, shape in enumerate(shape)}
711+
with Model(coords=coords):
712+
inferred_support_shape = get_support_shape(
713+
support_shape=support_shape,
714+
dims=dims,
715+
support_shape_offset=support_shape_offset,
716+
ndim_supp=ndim_supp,
717+
)
718+
719+
elif info_source == "observed":
720+
if shape is None:
721+
observed = None
722+
else:
723+
observed = np.zeros(shape)
724+
inferred_support_shape = get_support_shape(
725+
support_shape=support_shape,
726+
observed=observed,
727+
support_shape_offset=support_shape_offset,
728+
ndim_supp=ndim_supp,
729+
)
730+
731+
if not isinstance(inferred_support_shape, TensorVariable):
732+
assert inferred_support_shape == expected_support_shape
733+
else:
734+
if consistent:
735+
assert (inferred_support_shape.eval() == expected_support_shape).all()
736+
else:
737+
# check that inferred support shape is still correct by ignoring the assert
738+
f = aesara.function(
739+
[], inferred_support_shape, mode=Mode().including("local_remove_all_assert")
740+
)
741+
assert (f() == expected_support_shape).all()
742+
with pytest.raises(AssertionError, match="support_shape does not match"):
743+
inferred_support_shape.eval()

0 commit comments

Comments
 (0)