@@ -1768,7 +1768,7 @@ def update_expr(self, new_op): # noqa: C901
17681768 new_operands = {}
17691769 # where() handling requires evaluating the expression prior to merge.
17701770 # This is different from reductions, where the expression is evaluated
1771- # and returned an NumPy array (for usability convenience).
1771+ # and returned a NumPy array (for usability convenience).
17721772 # We do things like this to enable the fusion of operations like
17731773 # `a.where(0, 1).sum()`.
17741774 # Another possibility would have been to always evaluate where() and produce
@@ -1783,6 +1783,7 @@ def update_expr(self, new_op): # noqa: C901
17831783 # We converted some of the operands to NDArray (where() handling above)
17841784 new_operands = {"o0" : value1 , "o1" : value2 }
17851785 expression = f"(o0 { op } o1)"
1786+ return self ._new_expr (expression , new_operands , guess = False , out = None , where = None )
17861787 elif isinstance (value1 , LazyExpr ) and isinstance (value2 , LazyExpr ):
17871788 # Expression fusion
17881789 # Fuse operands in expressions and detect duplicates
@@ -2113,6 +2114,14 @@ def prod(self, axis=None, dtype=None, keepdims=False, **kwargs):
21132114 return self .compute (_reduce_args = reduce_args , ** kwargs )
21142115
21152116 def get_num_elements (self , axis , item ):
2117+ if hasattr (self , "_where_args" ) and len (self ._where_args ) == 1 :
2118+ # We have a where condition, so we need to count the number of elements
2119+ # fulfilling the condition
2120+ orig_where_args = self ._where_args
2121+ self ._where_args = {"_where_x" : blosc2 .ones (self .shape , dtype = np .int8 )}
2122+ num_elements = self .sum (axis = axis , dtype = np .int64 , item = item )
2123+ self ._where_args = orig_where_args
2124+ return num_elements
21162125 if np .isscalar (axis ):
21172126 axis = (axis ,)
21182127 # Compute the number of elements in the array
0 commit comments