|
37 | 37 | convert_shape,
|
38 | 38 | convert_size,
|
39 | 39 | get_broadcastable_dist_samples,
|
| 40 | + get_support_shape, |
40 | 41 | get_support_shape_1d,
|
41 | 42 | rv_size_is_none,
|
42 | 43 | shapes_broadcasting,
|
@@ -664,3 +665,79 @@ def test_get_support_shape_1d(
|
664 | 665 | assert f() == expected_support_shape
|
665 | 666 | with pytest.raises(AssertionError, match="support_shape does not match"):
|
666 | 667 | inferred_support_shape.eval()
|
| 668 | + |
| 669 | + |
| 670 | +@pytest.mark.parametrize( |
| 671 | + "support_shape, shape, support_shape_offset, expected_support_shape, ndim_supp, consistent", |
| 672 | + [ |
| 673 | + ((10, 5), None, (0,), (10, 5), 1, True), |
| 674 | + ((10, 5), None, (1, 1), (10, 5), 1, True), |
| 675 | + (None, (10, 5), (0,), 5, 1, True), |
| 676 | + (None, (10, 5), (1,), 4, 1, True), |
| 677 | + (None, (10, 5, 2), (0,), 2, 1, True), |
| 678 | + (None, None, None, None, 1, True), |
| 679 | + ((10, 5), (10, 5), None, (10, 5), 2, True), |
| 680 | + ((10, 5), (11, 10, 5), None, (10, 5), 2, True), |
| 681 | + (None, (11, 10, 5), (0, 1, 0), (11, 9, 5), 3, True), |
| 682 | + ((10, 5), (10, 5, 5), (0,), (5,), 1, False), |
| 683 | + ((10, 5), (10, 5), (1, 1), (9, 4), 2, False), |
| 684 | + ], |
| 685 | +) |
| 686 | +@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) |
| 687 | +def test_get_support_shape( |
| 688 | + info_source, |
| 689 | + support_shape, |
| 690 | + shape, |
| 691 | + support_shape_offset, |
| 692 | + expected_support_shape, |
| 693 | + ndim_supp, |
| 694 | + consistent, |
| 695 | +): |
| 696 | + if info_source == "shape": |
| 697 | + inferred_support_shape = get_support_shape( |
| 698 | + support_shape=support_shape, |
| 699 | + shape=shape, |
| 700 | + support_shape_offset=support_shape_offset, |
| 701 | + ndim_supp=ndim_supp, |
| 702 | + ) |
| 703 | + |
| 704 | + elif info_source == "dims": |
| 705 | + if shape is None: |
| 706 | + dims = None |
| 707 | + coords = {} |
| 708 | + else: |
| 709 | + dims = tuple(str(i) for i, _ in enumerate(shape)) |
| 710 | + coords = {str(i): range(shape) for i, shape in enumerate(shape)} |
| 711 | + with Model(coords=coords): |
| 712 | + inferred_support_shape = get_support_shape( |
| 713 | + support_shape=support_shape, |
| 714 | + dims=dims, |
| 715 | + support_shape_offset=support_shape_offset, |
| 716 | + ndim_supp=ndim_supp, |
| 717 | + ) |
| 718 | + |
| 719 | + elif info_source == "observed": |
| 720 | + if shape is None: |
| 721 | + observed = None |
| 722 | + else: |
| 723 | + observed = np.zeros(shape) |
| 724 | + inferred_support_shape = get_support_shape( |
| 725 | + support_shape=support_shape, |
| 726 | + observed=observed, |
| 727 | + support_shape_offset=support_shape_offset, |
| 728 | + ndim_supp=ndim_supp, |
| 729 | + ) |
| 730 | + |
| 731 | + if not isinstance(inferred_support_shape, TensorVariable): |
| 732 | + assert inferred_support_shape == expected_support_shape |
| 733 | + else: |
| 734 | + if consistent: |
| 735 | + assert (inferred_support_shape.eval() == expected_support_shape).all() |
| 736 | + else: |
| 737 | + # check that inferred support shape is still correct by ignoring the assert |
| 738 | + f = aesara.function( |
| 739 | + [], inferred_support_shape, mode=Mode().including("local_remove_all_assert") |
| 740 | + ) |
| 741 | + assert (f() == expected_support_shape).all() |
| 742 | + with pytest.raises(AssertionError, match="support_shape does not match"): |
| 743 | + inferred_support_shape.eval() |
0 commit comments