Skip to content

Commit dd8d9a8

Browse files
committed
Closes #5230: ArkoudaExtensionArray arithmetic
1 parent 4b458d1 commit dd8d9a8

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

arkouda/pandas/extension/_arkouda_extension_array.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,67 @@ def __len__(self):
9494
"""
9595
return len(self._data)
9696

97+
def _arith_method(self, other, op):
    """
    Compute ``op(self, other)`` elementwise on the underlying Arkouda data.

    Pandas routes the arithmetic operators of an ExtensionArray (``+``, ``-``,
    ``*``, ...) through this hook, so this is where Arkouda-backed
    arrays/Series get their arithmetic.

    Parameters
    ----------
    other : object
        Right-hand operand. Two kinds of operand are accepted:

        * an ``ArkoudaArray`` or ``ArkoudaStringArray`` exposing ``_data``;
          the wrapped Arkouda object is what gets passed to ``op``.
        * anything for which ``np.isscalar`` is true; it is passed to ``op``
          untouched.

        Every other operand yields ``NotImplemented`` so the caller
        (pandas / the Python operator protocol) can try another dispatch path.
    op : callable
        Binary operator such as ``operator.add``. It is invoked as
        ``op(self._data, other)`` and its result is handed to the
        constructor of ``type(self)``.

    Returns
    -------
    ExtensionArray or NotImplemented
        A new instance of ``type(self)`` wrapping the elementwise result, or
        ``NotImplemented`` when ``other`` is not a supported operand.

    Notes
    -----
    * No index alignment happens here; pandas aligns at the Series/DataFrame
      level before calling into the ExtensionArray.
    * Result dtype and promotion rules are whatever the underlying Arkouda
      implementation of ``op`` produces.

    Examples
    --------
    >>> import arkouda as ak
    >>> import pandas as pd
    >>> import operator
    >>> x = pd.array([1, 2, 3], dtype="ak_int64")
    >>> y = pd.array([10, 20, 30], dtype="ak_int64")
    >>> z = x._arith_method(y, operator.add)
    >>> z
    ArkoudaArray([11 22 33])
    """
    # Imported lazily, inside the method, rather than at module import time.
    from arkouda.pandas.extension import ArkoudaArray, ArkoudaStringArray

    unwrappable = (ArkoudaArray, ArkoudaStringArray)
    is_wrapped = isinstance(other, unwrappable) and hasattr(other, "_data")

    # Only wrapped Arkouda arrays and scalars are supported; the scalar check
    # is evaluated only when the operand is not a wrapped array.
    if not is_wrapped and not np.isscalar(other):
        return NotImplemented

    rhs = other._data if is_wrapped else other
    return type(self)(op(self._data, rhs))
157+
97158
def copy(self, deep: bool = True):
98159
"""
99160
Return a copy of the array.

tests/pandas/extension/arkouda_extension.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import operator
2+
13
import numpy as np
4+
import pandas as pd
25
import pytest
36

47
import arkouda as ak
@@ -378,6 +381,7 @@ def assert_indices(self, perm: pdarray, expected_py_indices):
378381
def test_argsort_pdarray_float_ascending_nan_positions(self, na_position, expected):
379382
a = ak.array([3.0, float("nan"), 1.0, 2.0])
380383
ea = ArkoudaExtensionArray(a)
384+
381385
perm = ea.argsort(ascending=True, na_position=na_position)
382386
self.assert_indices(perm, expected)
383387

@@ -650,3 +654,63 @@ def test_copy_shallow_creates_new_wrapper_but_shares_data(self, ea):
650654

651655
# Values are equal
652656
np.testing.assert_array_equal(shallow.to_numpy(), ea.to_numpy())
657+
658+
659+
class TestArkoudaExtensionArrayArithmatic:
    """Tests for ``_arith_method`` and ``+`` dispatch on Arkouda-backed arrays."""

    @staticmethod
    def _int_array(values):
        # Build the array through pandas using the "ak_int64" dtype string.
        return pd.array(values, dtype="ak_int64")

    @pytest.mark.parametrize(
        "op, np_op",
        # The Arkouda-side operator and the NumPy reference operator are the
        # same callable for each case.
        [(fn, fn) for fn in (operator.add, operator.sub, operator.mul)],
    )
    def test_arith_method_with_arkouda_array_operand(self, op, np_op):
        """Array-with-array arithmetic matches the NumPy reference result."""
        left_values = [1, 2, 3]
        right_values = [10, 20, 30]
        left = self._int_array(left_values)
        right = self._int_array(right_values)

        result = left._arith_method(right, op)

        assert type(result) is type(left)
        reference = np_op(np.array(left_values), np.array(right_values))
        np.testing.assert_array_equal(result.to_numpy(), reference)

    @pytest.mark.parametrize(
        "op, scalar, expected",
        [
            (operator.add, 5, np.array([6, 7, 8])),
            (operator.sub, 1, np.array([0, 1, 2])),
            (operator.mul, 2, np.array([2, 4, 6])),
        ],
    )
    def test_arith_method_with_scalar_operand(self, op, scalar, expected):
        """Array-with-scalar arithmetic applies the scalar to every element."""
        arr = self._int_array([1, 2, 3])

        result = arr._arith_method(scalar, op)

        assert type(result) is type(arr)
        np.testing.assert_array_equal(result.to_numpy(), expected)

    def test_arith_method_returns_notimplemented_for_unsupported_other(self):
        """A plain list is neither a scalar nor an Arkouda EA, so the hook declines."""
        arr = self._int_array([1, 2, 3])

        outcome = arr._arith_method([1, 2, 3], operator.add)

        assert outcome is NotImplemented

    def test_operator_add_raises_typeerror_for_unsupported_other(self):
        """The user-visible effect of a propagated ``NotImplemented`` is ``TypeError``."""
        arr = self._int_array([1, 2, 3])

        with pytest.raises(TypeError):
            _ = arr + [1, 2, 3]

    def test_arith_method_unwraps_other_data_attribute(self):
        """The right-hand EA is unwrapped via its ``_data`` attribute."""
        left = self._int_array([1, 2, 3])
        right = self._int_array([10, 20, 30])

        # Precondition: the operand is an EA carrying the attribute we unwrap.
        assert hasattr(right, "_data")

        total = left._arith_method(right, operator.add)
        np.testing.assert_array_equal(total.to_numpy(), np.array([11, 22, 33]))

0 commit comments

Comments
 (0)