1
1
import math
2
2
3
3
from cubed .array_api .dtypes import (
4
+ _boolean_dtypes ,
4
5
_numeric_dtypes ,
5
6
_real_floating_dtypes ,
6
7
_real_numeric_dtypes ,
@@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
124
125
def prod (
125
126
x , / , * , axis = None , dtype = None , keepdims = False , use_new_impl = True , split_every = None
126
127
):
127
- if x .dtype not in _numeric_dtypes :
128
- raise TypeError ("Only numeric dtypes are allowed in prod" )
128
+ # boolean is allowed by numpy
129
+ if x .dtype not in _numeric_dtypes and x .dtype not in _boolean_dtypes :
130
+ raise TypeError ("Only numeric or boolean dtypes are allowed in prod" )
129
131
if dtype is None :
130
- if x .dtype in _signed_integer_dtypes :
132
+ if x .dtype in _boolean_dtypes :
133
+ dtype = int64
134
+ elif x .dtype in _signed_integer_dtypes :
131
135
dtype = int64
132
136
elif x .dtype in _unsigned_integer_dtypes :
133
137
dtype = uint64
@@ -153,10 +157,13 @@ def prod(
153
157
def sum (
154
158
x , / , * , axis = None , dtype = None , keepdims = False , use_new_impl = True , split_every = None
155
159
):
156
- if x .dtype not in _numeric_dtypes :
157
- raise TypeError ("Only numeric dtypes are allowed in sum" )
160
+ # boolean is allowed by numpy
161
+ if x .dtype not in _numeric_dtypes and x .dtype not in _boolean_dtypes :
162
+ raise TypeError ("Only numeric or boolean dtypes are allowed in sum" )
158
163
if dtype is None :
159
- if x .dtype in _signed_integer_dtypes :
164
+ if x .dtype in _boolean_dtypes :
165
+ dtype = int64
166
+ elif x .dtype in _signed_integer_dtypes :
160
167
dtype = int64
161
168
elif x .dtype in _unsigned_integer_dtypes :
162
169
dtype = uint64
0 commit comments