Commit 5d0e513

Closes #5232: _reduce improvements for ArkoudaExtensionArray
1 parent 4b458d1 commit 5d0e513

2 files changed: +210 -19 lines changed


arkouda/pandas/extension/_arkouda_array.py

Lines changed: 93 additions & 16 deletions
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence, TypeVar
+import inspect
+
+from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar
 from typing import cast as type_cast
 
 import numpy as np
@@ -9,6 +11,7 @@
 from pandas.api.extensions import ExtensionArray
 
 from arkouda.numpy.dtypes import dtype as ak_dtype
+from arkouda.pandas.groupbyclass import GroupByReductionType
 
 from ._arkouda_extension_array import ArkoudaExtensionArray
 from ._dtypes import (
@@ -224,21 +227,95 @@ def equals(self, other):
             return False
         return self._data.equals(other._data)
 
-    def _reduce(self, name, skipna=True, **kwargs):
-        if name == "all":
-            return self._data.all()
-        elif name == "any":
-            return self._data.any()
-        elif name == "sum":
-            return self._data.sum()
-        elif name == "prod":
-            return self._data.prod()
-        elif name == "min":
-            return self._data.min()
-        elif name == "max":
-            return self._data.max()
-        else:
-            raise TypeError(f"'ArkoudaArray' with dtype arkouda does not support reduction '{name}'")
+    def _reduce(
+        self,
+        name: str | GroupByReductionType,
+        skipna: bool = True,
+        **kwargs: Any,
+    ) -> Any:
+        """
+        Reduce the underlying data.
+
+        Parameters
+        ----------
+        name : str | GroupByReductionType
+            Reduction name, e.g. "sum", "mean", "nunique", ...
+        skipna : bool
+            If supported by the underlying implementation, skip NaN/NA values.
+            Default is True.
+        **kwargs : Any
+            Extra args for compatibility (e.g. ddof for var/std).
+
+        Returns
+        -------
+        Any
+            The reduction result.
+
+        Raises
+        ------
+        TypeError
+            If ``name`` is not a supported reduction or the underlying data does not
+            implement the requested reduction.
+        """
+        # Normalize: accept Enum or str
+        if hasattr(name, "value"):  # enum-like
+            name = name.value
+        if isinstance(name, tuple) and len(name) == 1:  # guards against UNIQUE="unique",
+            name = name[0]
+        if not isinstance(name, str):
+            raise TypeError(f"Reduction name must be a string or GroupByReductionType, got {type(name)}")
+
+        data = self._data
+
+        def _call_method(method_name: str, *args: Any, **kw: Any) -> Any:
+            if not hasattr(data, method_name):
+                raise TypeError(
+                    f"'ArkoudaArray' with dtype {self.dtype} does not support reduction '{name}' "
+                    f"(missing method {method_name!r} on {type(data).__name__})"
+                )
+            meth = getattr(data, method_name)
+
+            # Best-effort: pass skipna/ddof/etc only if the method accepts them.
+            try:
+                sig = inspect.signature(meth)
+            except (TypeError, ValueError):
+                return meth(*args, **kw)
+
+            params = sig.parameters
+            filtered: dict[str, Any] = {k: v for k, v in kw.items() if k in params}
+            return meth(*args, **filtered)
+
+        reductions: dict[str, Callable[[], Any]] = {
+            "all": lambda: _call_method("all", skipna=skipna, **kwargs),
+            "any": lambda: _call_method("any", skipna=skipna, **kwargs),
+            "sum": lambda: _call_method("sum", skipna=skipna, **kwargs),
+            "prod": lambda: _call_method("prod", skipna=skipna, **kwargs),
+            "min": lambda: _call_method("min", skipna=skipna, **kwargs),
+            "max": lambda: _call_method("max", skipna=skipna, **kwargs),
+            "mean": lambda: _call_method("mean", skipna=skipna, **kwargs),
+            "median": lambda: _call_method("median", skipna=skipna, **kwargs),
+            "var": lambda: _call_method("var", skipna=skipna, **kwargs),
+            "std": lambda: _call_method("std", skipna=skipna, **kwargs),
+            "argmin": lambda: _call_method("argmin", skipna=skipna, **kwargs),
+            "argmax": lambda: _call_method("argmax", skipna=skipna, **kwargs),
+            "count": lambda: _call_method("count", **kwargs),
+            "nunique": lambda: _call_method("nunique", **kwargs),
+            "or": lambda: _call_method("or", skipna=skipna, **kwargs),
+            "and": lambda: _call_method("and", skipna=skipna, **kwargs),
+            "xor": lambda: _call_method("xor", skipna=skipna, **kwargs),
+            "first": lambda: _call_method("first", skipna=skipna, **kwargs),
+            "mode": lambda: _call_method("mode", skipna=skipna, **kwargs),
+            "unique": lambda: _call_method("unique", **kwargs),
+        }
+
+        fn = reductions.get(name)
+        if fn is None:
+            raise TypeError(
+                f"'ArkoudaArray' with dtype {self.dtype} does not support reduction '{name}'. "
+                f"Supported: {sorted(reductions)}"
+            )
+
+        return fn()
 
     def __eq__(self, other):
         """

tests/pandas/extension/arkouda_array_extension.py

Lines changed: 117 additions & 3 deletions
@@ -1,3 +1,5 @@
+import inspect
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -233,18 +235,130 @@ def test_argsort(self):
         sorted_vals = arr._data[perm]
         assert ak.is_sorted(sorted_vals)
 
-    @pytest.mark.parametrize("reduction", ["all", "any", "sum", "prod", "min", "max"])
-    def test_reduce_ops(self, reduction):
+    @pytest.mark.parametrize(
+        "reduction",
+        [
+            "all",
+            "any",
+            "sum",
+            "prod",
+            "min",
+            "max",
+            "mean",
+            "var",
+            "std",
+            "median",
+            "mode",
+            "unique",
+            "count",
+            "first",
+            "nunique",
+        ],
+    )
+    def test_reduce_scalar_ops(self, reduction):
+        ak_data = ak.arange(10)
+        arr = ArkoudaArray(ak_data)
+
+        if not hasattr(ak_data, reduction):
+            pytest.xfail(f"{reduction} not implemented on pdarray backend yet")
+
+        result = arr._reduce(reduction)
+
+        assert isinstance(result, numeric_and_bool_scalars)
+
+    @pytest.mark.parametrize("reduction", ["argmin", "argmax"])
+    def test_reduce_arg_ops(self, reduction):
         ak_data = ak.arange(10)
         arr = ArkoudaArray(ak_data)
+
+        result = arr._reduce(reduction)
+
+        assert isinstance(result, numeric_and_bool_scalars)
+        assert result >= 0
+
+    @pytest.mark.parametrize("reduction", ["or", "and", "xor"])
+    def test_reduce_bitwise_ops(self, reduction):
+        ak_data = ak.array([True, False, True, False])
+        arr = ArkoudaArray(ak_data)
+
+        if not hasattr(ak_data, reduction):
+            pytest.xfail(f"{reduction} not implemented on pdarray backend yet")
+
+        result = arr._reduce(reduction)
+
+        assert isinstance(result, bool)
+
+    @pytest.mark.parametrize("reduction", ["unique", "mode"])
+    def test_reduce_array_ops(self, reduction):
+        ak_data = ak.array([1, 2, 2, 3, 3, 3])
+        arr = ArkoudaArray(ak_data)
+
+        if not hasattr(ak_data, reduction):
+            pytest.xfail(f"{reduction} not implemented on pdarray backend yet")
+
         result = arr._reduce(reduction)
+
+        assert isinstance(result, ArkoudaArray | ak.pdarray)
+
+    @pytest.mark.parametrize(
+        "reduction",
+        ["sum", "mean", "min", "max", "var", "std"],
+    )
+    def test_reduce_skipna_kwarg(self, reduction):
+        ak_data = ak.array([1.0, 2.0, ak.nan, 4.0])
+        arr = ArkoudaArray(ak_data)
+
+        result = arr._reduce(reduction, skipna=True)
+
         assert isinstance(result, numeric_and_bool_scalars)
 
+    @pytest.mark.parametrize("reduction", ["var", "std"])
+    def test_reduce_ddof_kwarg(self, reduction):
+        ak_data = ak.array([1.0, 2.0, 3.0, 4.0])
+        arr = ArkoudaArray(ak_data)
+
+        meth = getattr(ak_data, reduction, None)
+        if meth is None:
+            pytest.xfail(f"{reduction} not implemented on pdarray backend yet")
+
+        try:
+            accepts_ddof = "ddof" in inspect.signature(meth).parameters
+        except (TypeError, ValueError):
+            # If we can't introspect, just do a smoke test
+            accepts_ddof = False
+
+        if not accepts_ddof:
+            # ddof should be ignored by the filtering logic (no crash)
+            r = arr._reduce(reduction, ddof=1)
+            assert isinstance(r, numeric_and_bool_scalars)
+            return
+
+        r0 = arr._reduce(reduction, ddof=0)
+        r1 = arr._reduce(reduction, ddof=1)
+
+        # Compare to numpy for correctness
+        np_data = np.array([1.0, 2.0, 3.0, 4.0])
+        expected0 = getattr(np_data, reduction)(ddof=0)
+        expected1 = getattr(np_data, reduction)(ddof=1)
+
+        assert np.isclose(r0, expected0)
+        assert np.isclose(r1, expected1)
+
+    @pytest.mark.parametrize("reduction", ["sum", "mean", "min", "max"])
+    def test_reduce_ignores_unknown_kwargs(self, reduction):
+        ak_data = ak.arange(10)
+        arr = ArkoudaArray(ak_data)
+
+        r1 = arr._reduce(reduction)
+        r2 = arr._reduce(reduction, totally_not_a_real_kwarg=123)
+
+        assert r1 == r2
+
     def test_reduce_invalid(self):
         ak_data = ak.arange(10)
         arr = ArkoudaArray(ak_data)
         with pytest.raises(TypeError):
-            arr._reduce("mean")
+            arr._reduce("test")
 
     def test_concat_same_type(self):
         a1 = ArkoudaArray(ak.array([1, 2]))
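
test_reduce_ddof_kwarg and test_reduce_ignores_unknown_kwargs above lean on the signature filtering inside _call_method: extra keyword arguments are forwarded only when the backend method actually declares them. A self-contained sketch of that pattern, for illustration only (plain Python, no Arkouda server needed; the helper name call_filtered and the toy var function are not from the commit):

    import inspect

    def var(values, ddof=0):
        # toy stand-in for a backend reduction that accepts ddof but not skipna
        mean = sum(values) / len(values)
        return sum((v - mean) ** 2 for v in values) / (len(values) - ddof)

    def call_filtered(fn, *args, **kwargs):
        # forward only the keyword arguments that fn's signature declares
        params = inspect.signature(fn).parameters
        return fn(*args, **{k: v for k, v in kwargs.items() if k in params})

    # skipna is silently dropped, ddof is passed through
    print(call_filtered(var, [1.0, 2.0, 3.0, 4.0], ddof=1, skipna=True))  # ~1.6667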
