Skip to content

Commit 6752317

Browse files
chore: as per suggestion of @ricardoV94
1 parent 145d28d commit 6752317

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

pymc/logprob/mixture.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
is_basic_idx,
6363
)
6464
from pytensor.tensor.type import TensorType
65-
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
65+
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
6666
from pytensor.tensor.variable import TensorVariable
6767

6868
from pymc.logprob.abstract import (
@@ -290,14 +290,11 @@ def find_measurable_index_mixture(fgraph, node):
290290
# but the Mixture logprob assumes all mixture values are independent
291291
if any(
292292
(
293-
isinstance(indices, (type(NoneConst) | type(None)))
294-
or (
295-
indices.dtype.startswith("int")
296-
and sum(1 - b for b in indices.type.broadcastable) > 0
297-
)
293+
isinstance(indices, TensorVariable)
294+
and indices.dtype.startswith("int")
295+
and any(not b for b in indices.type.broadcastable)
298296
)
299297
for indices in mixing_indices
300-
if not isinstance(indices, SliceConstant)
301298
):
302299
return None
303300

0 commit comments

Comments
 (0)