@@ -48,7 +48,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
48
48
49
49
@node_rewriter ([Sum ])
50
50
def boolean_indexing_sum (fgraph , node ):
51
- """Replace the sum of `AdvancedSubtensor` with boolean indexing.
51
+ """Replace the sum of `AdvancedSubtensor` with exclusively boolean indexing.
52
52
53
53
JAX cannot JIT-compile functions that use boolean indexing, but can compile
54
54
those expressions that can be re-expressed using `jax.numpy.where`. This
@@ -61,21 +61,30 @@ def boolean_indexing_sum(fgraph, node):
61
61
if not isinstance (operand , TensorVariable ):
62
62
return
63
63
64
+ # If it's not a scalar reduction, it couldn't have been a pure boolean mask
65
+ if node .outputs [0 ].ndim != 0 :
66
+ return
67
+
64
68
if operand .owner is None :
65
69
return
66
70
67
71
if not isinstance (operand .owner .op , AdvancedSubtensor ):
68
72
return
69
73
70
- x = operand .owner .inputs [0 ]
71
- cond = operand .owner .inputs [1 ]
74
+ # Get out if AdvancedSubtensor has more than a single indexing operation
75
+ if len (operand .owner .inputs ) > 2 :
76
+ return
77
+
78
+ [x , cond ] = operand .owner .inputs
72
79
73
80
if not isinstance (cond , TensorVariable ):
74
81
return
75
82
76
83
if not cond .type .dtype == "bool" :
77
84
return
78
85
86
+ # Output must be a scalar, since pure boolean indexing returns a vector
87
+ # No need to worry about axis
79
88
out = at .sum (at .where (cond , x , 0 ))
80
89
return out .owner .outputs
81
90
0 commit comments