Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion arkouda/pandas/extension/_arkouda_extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,25 @@

"""

from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from __future__ import annotations

from types import NotImplementedType
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, TypeVar, Union

import numpy as np

from pandas.api.extensions import ExtensionArray
from typing_extensions import Self

from arkouda.numpy.dtypes import all_scalars
from arkouda.numpy.pdarrayclass import pdarray
from arkouda.numpy.pdarraysetops import concatenate as ak_concat
from arkouda.pandas.categorical import Categorical


# Self-type for correct return typing
EA = TypeVar("EA", bound="ExtensionArray")

if TYPE_CHECKING:
from arkouda.numpy.strings import Strings
else:
Expand All @@ -73,6 +80,8 @@ def _ensure_numpy(x):
class ArkoudaExtensionArray(ExtensionArray):
default_fill_value: Optional[Union[all_scalars, str]] = -1

_data: Any

def __init__(self, data):
# Subclasses should ensure this is the correct ak object
self._data = data
Expand All @@ -94,6 +103,62 @@ def __len__(self):
"""
return len(self._data)

@classmethod
def _from_data(cls: type[Self], data: Any) -> Self:
    """
    Build a new instance of ``cls`` that wraps ``data``.

    Parameters
    ----------
    data : Any
        Object handed unchanged to the class constructor.

    Returns
    -------
    Self
        The result of ``cls(data)``.
    """
    wrapped = cls(data)
    return wrapped

def _arith_method(
    self,
    other: object,
    op: Callable[[Any, Any], Any],
) -> Union[Self, NotImplementedType]:
    """
    Apply an elementwise arithmetic operation between this ExtensionArray and
    ``other``.

    This is the pandas ExtensionArray arithmetic hook. Pandas uses this method
    (via its internal operator dispatch) to implement operators like ``+``,
    ``-``, ``*``, etc. for arrays/Series backed by Arkouda.

    Parameters
    ----------
    other : object
        The right-hand operand. Supported forms:

        * ``ArkoudaExtensionArray`` (or a subclass): the operand is unwrapped
          to its underlying ``_data`` object before ``op`` is applied.
        * scalar: anything for which ``np.isscalar`` is true; it is passed to
          ``op`` unchanged.

        Any other type -- including lists, NumPy arrays and ExtensionArrays
        that are not Arkouda-backed -- returns ``NotImplemented`` so that
        pandas/Python can fall back to alternate dispatch paths.
    op : callable
        A binary operator (e.g., ``operator.add``). Called as
        ``op(self._data, other)``; its result is wrapped in ``type(self)``.

    Returns
    -------
    ExtensionArray or NotImplemented
        A new array of the same ExtensionArray class as ``self`` containing the
        elementwise result, or ``NotImplemented`` for unsupported operand types.

    Notes
    -----
    * This method does **not** perform index alignment; pandas handles alignment
      at the Series/DataFrame level before calling into the ExtensionArray.
    * Type coercion / promotion behavior is determined by the underlying Arkouda
      implementation of ``op``.
    * Exceptions raised by ``op`` itself are not caught here.
    """
    if isinstance(other, ArkoudaExtensionArray):
        # Only unwrap our own arrays. A bare ``hasattr(other, "_data")`` check
        # is too loose: ``_data`` is a private name that other ExtensionArrays
        # use as well (e.g. pandas' masked arrays keep their values in
        # ``_data`` and their NA positions in a separate mask), so unwrapping
        # them would silently drop information.
        other = other._data
    elif not np.isscalar(other):
        # Unsupported operand: defer so the other operand's reflected method
        # (or a TypeError) takes over.
        return NotImplemented

    result = op(self._data, other)
    return type(self)(result)

def copy(self, deep: bool = True):
"""
Return a copy of the array.
Expand Down
64 changes: 64 additions & 0 deletions tests/pandas/extension/arkouda_extension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import operator

import numpy as np
import pandas as pd
import pytest

import arkouda as ak
Expand Down Expand Up @@ -378,6 +381,7 @@ def assert_indices(self, perm: pdarray, expected_py_indices):
def test_argsort_pdarray_float_ascending_nan_positions(self, na_position, expected):
a = ak.array([3.0, float("nan"), 1.0, 2.0])
ea = ArkoudaExtensionArray(a)

perm = ea.argsort(ascending=True, na_position=na_position)
self.assert_indices(perm, expected)

Expand Down Expand Up @@ -650,3 +654,63 @@ def test_copy_shallow_creates_new_wrapper_but_shares_data(self, ea):

# Values are equal
np.testing.assert_array_equal(shallow.to_numpy(), ea.to_numpy())


class TestArkoudaExtensionArrayArithmatic:
    """Checks for ``_arith_method`` on arrays created with ``dtype="ak_int64"``."""

    @pytest.mark.parametrize(
        "op, np_op",
        [(fn, fn) for fn in (operator.add, operator.sub, operator.mul)],
    )
    def test_arith_method_with_arkouda_array_operand(self, op, np_op):
        """Array-with-array arithmetic keeps the wrapper type and matches NumPy."""
        left_values = [1, 2, 3]
        right_values = [10, 20, 30]
        lhs = pd.array(left_values, dtype="ak_int64")
        rhs = pd.array(right_values, dtype="ak_int64")

        result = lhs._arith_method(rhs, op)

        assert type(result) is type(lhs)
        reference = np_op(np.array(left_values), np.array(right_values))
        np.testing.assert_array_equal(result.to_numpy(), reference)

    @pytest.mark.parametrize(
        "op, scalar, expected",
        [
            (operator.add, 5, np.array([6, 7, 8])),
            (operator.sub, 1, np.array([0, 1, 2])),
            (operator.mul, 2, np.array([2, 4, 6])),
        ],
    )
    def test_arith_method_with_scalar_operand(self, op, scalar, expected):
        """Array-with-scalar arithmetic keeps the wrapper type and the expected values."""
        lhs = pd.array([1, 2, 3], dtype="ak_int64")

        result = lhs._arith_method(scalar, op)

        assert type(result) is type(lhs)
        np.testing.assert_array_equal(result.to_numpy(), expected)

    def test_arith_method_returns_notimplemented_for_unsupported_other(self):
        """A plain list operand makes the hook return ``NotImplemented``."""
        lhs = pd.array([1, 2, 3], dtype="ak_int64")

        # A list is neither a scalar nor an Arkouda EA, so the hook must defer.
        unsupported = [1, 2, 3]
        assert lhs._arith_method(unsupported, operator.add) is NotImplemented

    def test_operator_add_raises_typeerror_for_unsupported_other(self):
        """The ``+`` operator surfaces the deferred dispatch as ``TypeError``."""
        # User-visible consequence of NotImplemented propagating out of the hook.
        lhs = pd.array([1, 2, 3], dtype="ak_int64")

        with pytest.raises(TypeError):
            _ = lhs + [1, 2, 3]

    def test_arith_method_unwraps_other_data_attribute(self):
        """An extension-array operand is unwrapped through its ``_data`` attribute."""
        lhs = pd.array([1, 2, 3], dtype="ak_int64")
        rhs = pd.array([10, 20, 30], dtype="ak_int64")

        # Guard: the operand must expose the attribute that gets unwrapped.
        assert hasattr(rhs, "_data")

        summed = lhs._arith_method(rhs, operator.add)
        np.testing.assert_array_equal(summed.to_numpy(), np.array([11, 22, 33]))