Skip to content

Commit d2c34a5

Browse files
authored
Add support for complex floating-point data types in mean (#796)
1 parent ab233c9 commit d2c34a5

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

cubed/array_api/statistical_functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22

3+
from cubed.array_api.data_type_functions import isdtype
34
from cubed.array_api.dtypes import (
5+
_floating_dtypes,
46
_real_floating_dtypes,
57
_real_numeric_dtypes,
68
_upcast_integral_dtypes,
@@ -94,14 +96,17 @@ def max(x, /, *, axis=None, keepdims=False, split_every=None):
9496

9597

9698
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
97-
if x.dtype not in _real_floating_dtypes:
98-
raise TypeError("Only real floating-point dtypes are allowed in mean")
99+
if x.dtype not in _floating_dtypes:
100+
raise TypeError("Only floating-point dtypes are allowed in mean")
99101
# This implementation uses a Zarr group of two arrays to store a
100102
# pair of fields needed to keep per-chunk counts and totals for computing
101103
# the mean.
102104
dtype = x.dtype
103105
# TODO(#658): Should these be default dtypes?
104-
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
106+
if isdtype(x.dtype, "complex floating"):
107+
intermediate_dtype = [("n", nxp.int64), ("total", nxp.complex128)]
108+
else:
109+
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
105110
extra_func_kwargs = dict(dtype=intermediate_dtype)
106111
return reduction(
107112
x,

cubed/tests/test_array_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,11 @@ def test_mean_axis_0(spec, executor):
904904
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).mean(axis=0),
905905
)
906906

907+
def test_mean_complex():
908+
a = xp.asarray([1.0+1.0j, 2.0+2.0j, 3.0+3.0j], chunks=(2,))
909+
b = xp.mean(a)
910+
assert_array_equal(b.compute(), np.array([1.0+1.0j, 2.0+2.0j, 3.0+3.0j]).mean())
911+
907912

908913
def test_sum(spec, executor):
909914
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)

0 commit comments

Comments
 (0)