Skip to content

Commit 7357416

Browse files
committed
Closes #5230: ArkoudaExtensionArray arithmetic
1 parent 4b458d1 commit 7357416

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

arkouda/pandas/extension/_arkouda_extension_array.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,25 @@
4444
4545
"""
4646

47-
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
47+
from __future__ import annotations
48+
49+
from types import NotImplementedType
50+
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, TypeVar, Union
4851

4952
import numpy as np
5053

5154
from pandas.api.extensions import ExtensionArray
55+
from typing_extensions import Self
5256

5357
from arkouda.numpy.dtypes import all_scalars
5458
from arkouda.numpy.pdarrayclass import pdarray
5559
from arkouda.numpy.pdarraysetops import concatenate as ak_concat
5660
from arkouda.pandas.categorical import Categorical
5761

5862

63+
# Self-type for correct return typing
64+
EA = TypeVar("EA", bound="ExtensionArray")
65+
5966
if TYPE_CHECKING:
6067
from arkouda.numpy.strings import Strings
6168
else:
@@ -73,6 +80,8 @@ def _ensure_numpy(x):
7380
class ArkoudaExtensionArray(ExtensionArray):
7481
default_fill_value: Optional[Union[all_scalars, str]] = -1
7582

83+
_data: Any
84+
7685
def __init__(self, data):
7786
# Subclasses should ensure this is the correct ak object
7887
self._data = data
@@ -94,6 +103,62 @@ def __len__(self):
94103
"""
95104
return len(self._data)
96105

106+
@classmethod
107+
def _from_data(cls: type[Self], data: Any) -> Self:
108+
return cls(data)
109+
110+
def _arith_method(
111+
self,
112+
other: object,
113+
op: Callable[[Any, Any], Any],
114+
) -> Union[Self, NotImplementedType]:
115+
"""
116+
Apply an elementwise arithmetic operation between this ExtensionArray and
117+
``other``.
118+
119+
This is the pandas ExtensionArray arithmetic hook. Pandas uses this method
120+
(via its internal operator dispatch) to implement operators like ``+``,
121+
``-``, ``*``, etc. for arrays/Series backed by Arkouda.
122+
123+
Parameters
124+
----------
125+
other : object
126+
The right-hand operand. Supported forms:
127+
128+
* ExtensionArray with a ``_data`` attribute: the operand is unwrapped to
129+
its underlying Arkouda data.
130+
* scalar: any NumPy scalar / Python scalar supported by the underlying
131+
Arkouda operation.
132+
133+
Any other type returns ``NotImplemented`` so that pandas/Python can fall
134+
back to alternate dispatch paths.
135+
op : callable
136+
A binary operator (e.g., ``operator.add``). Must accept
137+
``(self._data, other)`` and return an Arkouda-backed result.
138+
139+
Returns
140+
-------
141+
ExtensionArray or NotImplemented
142+
A new array of the same ExtensionArray class as ``self`` containing the
143+
elementwise result, or ``NotImplemented`` for unsupported operand types.
144+
145+
Notes
146+
-----
147+
* This method does **not** perform index alignment; pandas handles alignment
148+
at the Series/DataFrame level before calling into the ExtensionArray.
149+
* Type coercion / promotion behavior is determined by the underlying Arkouda
150+
implementation of ``op``.
151+
"""
152+
if isinstance(other, ExtensionArray) and hasattr(other, "_data"):
153+
other = other._data
154+
elif np.isscalar(other):
155+
pass
156+
else:
157+
return NotImplemented
158+
159+
result = op(self._data, other)
160+
return type(self)(result)
161+
97162
def copy(self, deep: bool = True):
98163
"""
99164
Return a copy of the array.

tests/pandas/extension/arkouda_extension.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import operator
2+
13
import numpy as np
4+
import pandas as pd
25
import pytest
36

47
import arkouda as ak
@@ -378,6 +381,7 @@ def assert_indices(self, perm: pdarray, expected_py_indices):
378381
def test_argsort_pdarray_float_ascending_nan_positions(self, na_position, expected):
379382
a = ak.array([3.0, float("nan"), 1.0, 2.0])
380383
ea = ArkoudaExtensionArray(a)
384+
381385
perm = ea.argsort(ascending=True, na_position=na_position)
382386
self.assert_indices(perm, expected)
383387

@@ -650,3 +654,63 @@ def test_copy_shallow_creates_new_wrapper_but_shares_data(self, ea):
650654

651655
# Values are equal
652656
np.testing.assert_array_equal(shallow.to_numpy(), ea.to_numpy())
657+
658+
659+
class TestArkoudaExtensionArrayArithmatic:
    """Checks for ``_arith_method`` and ``+`` dispatch on ``ak_int64`` arrays."""

    @pytest.mark.parametrize(
        "op, np_op",
        [
            (operator.add, operator.add),
            (operator.sub, operator.sub),
            (operator.mul, operator.mul),
        ],
    )
    def test_arith_method_with_arkouda_array_operand(self, op, np_op):
        # Array-with-array: the result must match the same operator applied
        # to equivalent NumPy arrays.
        lhs_values = [1, 2, 3]
        rhs_values = [10, 20, 30]
        lhs = pd.array(lhs_values, dtype="ak_int64")
        rhs = pd.array(rhs_values, dtype="ak_int64")

        result = lhs._arith_method(rhs, op)

        assert type(result) is type(lhs)
        reference = np_op(np.array(lhs_values), np.array(rhs_values))
        np.testing.assert_array_equal(result.to_numpy(), reference)

    @pytest.mark.parametrize(
        "op, scalar, expected",
        [
            (operator.add, 5, np.array([6, 7, 8])),
            (operator.sub, 1, np.array([0, 1, 2])),
            (operator.mul, 2, np.array([2, 4, 6])),
        ],
    )
    def test_arith_method_with_scalar_operand(self, op, scalar, expected):
        # Array-with-scalar: the scalar is applied to every element.
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        result = arr._arith_method(scalar, op)

        assert type(result) is type(arr)
        np.testing.assert_array_equal(result.to_numpy(), expected)

    def test_arith_method_returns_notimplemented_for_unsupported_other(self):
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        # A plain list is neither a scalar nor an ExtensionArray, so the
        # hook is expected to decline the operation.
        result = arr._arith_method([1, 2, 3], operator.add)
        assert result is NotImplemented

    def test_operator_add_raises_typeerror_for_unsupported_other(self):
        # User-visible consequence of NotImplemented propagating out of the
        # hook: the ``+`` operator itself fails.
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        with pytest.raises(TypeError):
            _ = arr + [1, 2, 3]

    def test_arith_method_unwraps_other_data_attribute(self):
        # Guards the unwrap path: the right operand must expose ``_data``.
        lhs = pd.array([1, 2, 3], dtype="ak_int64")
        rhs = pd.array([10, 20, 30], dtype="ak_int64")

        assert hasattr(rhs, "_data")

        result = lhs._arith_method(rhs, operator.add)
        np.testing.assert_array_equal(result.to_numpy(), np.array([11, 22, 33]))

0 commit comments

Comments
 (0)