Skip to content

Commit 7aafe10

Browse files
committed
Update stale xfail tests
1 parent 042a1c5 commit 7aafe10

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

pymc/tests/test_shape_handling.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -367,35 +367,35 @@ def test_can_resize_data_defined_size(self):
367367
assert y.eval().shape == (3, 2)
368368
assert z.eval().shape == (3, 2)
369369

370-
@pytest.mark.xfail(reason="https://github.com/pymc-devs/aesara/issues/390")
371-
def test_size32_doesnt_break_broadcasting():
370+
def test_size32_doesnt_break_broadcasting(self):
372371
size32 = at.constant([1, 10], dtype="int32")
373372
rv = pm.Normal.dist(0, 1, size=size32)
374373
assert rv.broadcastable == (True, False)
375374

376-
@pytest.mark.xfail(reason="https://github.com/pymc-devs/aesara/issues/390")
377375
def test_observed_with_column_vector(self):
378376
"""This test is related to https://github.com/pymc-devs/aesara/issues/390 which breaks
379377
broadcastability of column-vector RVs. This unexpected change in type can lead to
380378
incompatibilities during graph rewriting for model.logp evaluation.
381379
"""
382380
with pm.Model() as model:
383381
# The `observed` is a broadcastable column vector
384-
obs = at.as_tensor_variable(np.ones((3, 1), dtype=aesara.config.floatX))
385-
assert obs.broadcastable == (False, True)
382+
obs = [
383+
at.as_tensor_variable(np.ones((3, 1), dtype=aesara.config.floatX)) for _ in range(4)
384+
]
385+
assert all(obs_.broadcastable == (False, True) for obs_ in obs)
386386

387387
# Both shapes describe broadcastable volumn vectors
388388
size64 = at.constant([3, 1], dtype="int64")
389389
# But the second shape is upcasted from an int32 vector
390390
cast64 = at.cast(at.constant([3, 1], dtype="int32"), dtype="int64")
391391

392-
pm.Normal("size64", mu=0, sigma=1, size=size64, observed=obs)
393-
pm.Normal("shape64", mu=0, sigma=1, shape=size64, observed=obs)
394-
model.logp()
392+
pm.Normal("size64", mu=0, sigma=1, size=size64, observed=obs[0])
393+
pm.Normal("shape64", mu=0, sigma=1, shape=size64, observed=obs[1])
394+
assert model.compile_logp()({})
395395

396-
pm.Normal("size_cast64", mu=0, sigma=1, size=cast64, observed=obs)
397-
pm.Normal("shape_cast64", mu=0, sigma=1, shape=cast64, observed=obs)
398-
model.logp()
396+
pm.Normal("size_cast64", mu=0, sigma=1, size=cast64, observed=obs[2])
397+
pm.Normal("shape_cast64", mu=0, sigma=1, shape=cast64, observed=obs[3])
398+
assert model.compile_logp()({})
399399

400400
def test_dist_api_works(self):
401401
mu = aesara.shared(np.array([1, 2, 3]))

0 commit comments

Comments
 (0)