@@ -367,35 +367,35 @@ def test_can_resize_data_defined_size(self):
367
367
assert y .eval ().shape == (3 , 2 )
368
368
assert z .eval ().shape == (3 , 2 )
369
369
370
- @pytest .mark .xfail (reason = "https://github.com/pymc-devs/aesara/issues/390" )
371
- def test_size32_doesnt_break_broadcasting ():
370
+ def test_size32_doesnt_break_broadcasting (self ):
372
371
size32 = at .constant ([1 , 10 ], dtype = "int32" )
373
372
rv = pm .Normal .dist (0 , 1 , size = size32 )
374
373
assert rv .broadcastable == (True , False )
375
374
376
- @pytest .mark .xfail (reason = "https://github.com/pymc-devs/aesara/issues/390" )
377
375
def test_observed_with_column_vector (self ):
378
376
"""This test is related to https://github.com/pymc-devs/aesara/issues/390 which breaks
379
377
broadcastability of column-vector RVs. This unexpected change in type can lead to
380
378
incompatibilities during graph rewriting for model.logp evaluation.
381
379
"""
382
380
with pm .Model () as model :
383
381
# The `observed` is a broadcastable column vector
384
- obs = at .as_tensor_variable (np .ones ((3 , 1 ), dtype = aesara .config .floatX ))
385
- assert obs .broadcastable == (False , True )
382
+ obs = [
383
+ at .as_tensor_variable (np .ones ((3 , 1 ), dtype = aesara .config .floatX )) for _ in range (4 )
384
+ ]
385
+ assert all (obs_ .broadcastable == (False , True ) for obs_ in obs )
386
386
387
387
# Both shapes describe broadcastable volumn vectors
388
388
size64 = at .constant ([3 , 1 ], dtype = "int64" )
389
389
# But the second shape is upcasted from an int32 vector
390
390
cast64 = at .cast (at .constant ([3 , 1 ], dtype = "int32" ), dtype = "int64" )
391
391
392
- pm .Normal ("size64" , mu = 0 , sigma = 1 , size = size64 , observed = obs )
393
- pm .Normal ("shape64" , mu = 0 , sigma = 1 , shape = size64 , observed = obs )
394
- model .logp ( )
392
+ pm .Normal ("size64" , mu = 0 , sigma = 1 , size = size64 , observed = obs [ 0 ] )
393
+ pm .Normal ("shape64" , mu = 0 , sigma = 1 , shape = size64 , observed = obs [ 1 ] )
394
+ assert model .compile_logp ()({} )
395
395
396
- pm .Normal ("size_cast64" , mu = 0 , sigma = 1 , size = cast64 , observed = obs )
397
- pm .Normal ("shape_cast64" , mu = 0 , sigma = 1 , shape = cast64 , observed = obs )
398
- model .logp ( )
396
+ pm .Normal ("size_cast64" , mu = 0 , sigma = 1 , size = cast64 , observed = obs [ 2 ] )
397
+ pm .Normal ("shape_cast64" , mu = 0 , sigma = 1 , shape = cast64 , observed = obs [ 3 ] )
398
+ assert model .compile_logp ()({} )
399
399
400
400
def test_dist_api_works (self ):
401
401
mu = aesara .shared (np .array ([1 , 2 , 3 ]))
0 commit comments