Skip to content

Commit 7f3c032

Browse files
committed
Workaround buggy AdvancedIndexing Mixture logprob
Now that Dimshuffle lift broadcasts both the parameters and the size, the AdvancedIndexing logprob fails most of the times, even though these are valid graphs. All but one of the failing cases can be helped by introducing the local_rv_size_lift rewrite.
1 parent 547cf57 commit 7f3c032

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

pymc/logprob/mixture.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pytensor.tensor.elemwise import Elemwise
5454
from pytensor.tensor.random.rewriting import (
5555
local_dimshuffle_rv_lift,
56+
local_rv_size_lift,
5657
local_subtensor_rv_lift,
5758
)
5859
from pytensor.tensor.shape import shape_tuple
@@ -210,6 +211,7 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
210211
return pre_greedy_node_rewriter(
211212
fgraph,
212213
[
214+
local_rv_size_lift,
213215
local_dimshuffle_rv_lift,
214216
local_subtensor_rv_lift,
215217
naive_bcast_rv_lift,

pymc/tests/logprob/test_mixture.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def test_hetero_mixture_binomial(p_val, size):
324324
0,
325325
),
326326
# Same as before but with degenerate vector parameters
327-
pytest.param(
327+
(
328328
(
329329
np.array([0], dtype=pytensor.config.floatX),
330330
np.array(1, dtype=pytensor.config.floatX),
@@ -342,7 +342,6 @@ def test_hetero_mixture_binomial(p_val, size):
342342
(2,),
343343
(),
344344
0,
345-
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndexing Mixture logprob"),
346345
),
347346
(
348347
(
@@ -382,7 +381,7 @@ def test_hetero_mixture_binomial(p_val, size):
382381
(),
383382
0,
384383
),
385-
(
384+
pytest.param(
386385
(
387386
np.array(0, dtype=pytensor.config.floatX),
388387
np.array(1, dtype=pytensor.config.floatX),
@@ -400,6 +399,7 @@ def test_hetero_mixture_binomial(p_val, size):
400399
(3,),
401400
(slice(None),),
402401
1,
402+
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
403403
),
404404
(
405405
(

0 commit comments

Comments
 (0)