Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion arkouda/pandas/extension/_arkouda_extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@

"""

from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from __future__ import annotations

from types import NotImplementedType
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Tuple, TypeVar, Union

import numpy as np

from pandas.api.extensions import ExtensionArray
from typing_extensions import Self

from arkouda.numpy.dtypes import all_scalars
from arkouda.numpy.pdarrayclass import pdarray
Expand All @@ -73,6 +77,8 @@ def _ensure_numpy(x):
class ArkoudaExtensionArray(ExtensionArray):
default_fill_value: Optional[Union[all_scalars, str]] = -1

_data: Any

def __init__(self, data):
    """
    Store ``data`` as the backing object of this extension array.

    The value is assigned to ``self._data`` as-is; no validation or
    conversion is performed here.

    Parameters
    ----------
    data : Any
        The object to wrap.
    """
    # Subclasses should ensure this is the correct ak object
    self._data = data
Expand All @@ -94,6 +100,70 @@ def __len__(self):
"""
return len(self._data)

@classmethod
def _from_data(cls: type[Self], data: Any) -> Self:
    """
    Construct a new instance of ``cls`` wrapping ``data``.

    This simply calls ``cls(data)``, so the returned object has the concrete
    class the method was invoked on.

    Parameters
    ----------
    data : Any
        The object handed to the class constructor.

    Returns
    -------
    Self
        The result of ``cls(data)``.
    """
    return cls(data)

def _arith_method(
    self,
    other: object,
    op: Callable[[Any, Any], Any],
) -> Union[Self, NotImplementedType]:
    """
    Evaluate ``op(self._data, other)`` and wrap the result in a new array.

    This is the arithmetic hook of the pandas ExtensionArray interface; it is
    what backs operators such as ``+``, ``-`` and ``*`` for this class.

    Parameters
    ----------
    other : object
        The right-hand operand. Two kinds of operand are handled:

        * An ``ExtensionArray`` exposing a ``_data`` attribute. Its ``_data``
          is extracted and then normalised:

          - if it is an ``np.ndarray``, ``Iterable``, ``pdarray`` or
            ``Strings``, it is passed through ``ak_array(..., copy=False)``;
          - otherwise, if it is a ``Categorical``, ``to_strings()`` is
            called on it;
          - otherwise ``NotImplemented`` is returned.

        * Anything for which ``np.isscalar`` is true; it is forwarded to
          ``op`` unchanged.

        Every other operand (a plain ``list``, for instance) yields
        ``NotImplemented``.
    op : callable
        Binary callable invoked as ``op(self._data, operand)``, e.g.
        ``operator.add``.

    Returns
    -------
    ExtensionArray or NotImplemented
        ``self._from_data(op(self._data, operand))`` when the operand is
        supported, otherwise the ``NotImplemented`` singleton.

    Notes
    -----
    * No index alignment is done in this method.
    * No dtype coercion is done in this method beyond whatever ``op`` itself
      does.
    * NOTE(review): ``Iterable`` is tested before ``Categorical``. If
      ``Categorical`` defines ``__iter__`` it will be caught by the first
      test and the ``to_strings()`` branch can never run - confirm against
      the ``Categorical`` class before relying on that branch.
    """
    from arkouda.numpy.pdarraycreation import array as ak_array

    if isinstance(other, ExtensionArray) and hasattr(other, "_data"):
        # Wrapped operand: work on the object it carries, not the wrapper.
        payload = other._data

        if isinstance(payload, (np.ndarray, Iterable, pdarray, Strings)):
            rhs = ak_array(payload, copy=False)
        elif isinstance(payload, Categorical):
            rhs = payload.to_strings()
        else:
            return NotImplemented

        return self._from_data(op(self._data, rhs))

    if np.isscalar(other):
        # Scalars go to ``op`` untouched.
        return self._from_data(op(self._data, other))

    # Neither a wrapped array nor a scalar: let Python try other dispatch.
    return NotImplemented

def copy(self, deep: bool = True):
"""
Return a copy of the array.
Expand Down
68 changes: 68 additions & 0 deletions tests/pandas/extension/arkouda_extension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import operator

import numpy as np
import pandas as pd
import pytest

import arkouda as ak
Expand Down Expand Up @@ -378,6 +381,7 @@ def assert_indices(self, perm: pdarray, expected_py_indices):
def test_argsort_pdarray_float_ascending_nan_positions(self, na_position, expected):
a = ak.array([3.0, float("nan"), 1.0, 2.0])
ea = ArkoudaExtensionArray(a)

perm = ea.argsort(ascending=True, na_position=na_position)
self.assert_indices(perm, expected)

Expand Down Expand Up @@ -650,3 +654,67 @@ def test_copy_shallow_creates_new_wrapper_but_shares_data(self, ea):

# Values are equal
np.testing.assert_array_equal(shallow.to_numpy(), ea.to_numpy())


# NOTE(review): "Arithmatic" is a misspelling of "Arithmetic"; the name is kept
# unchanged so existing test IDs and ``-k`` selectors keep working.
class TestArkoudaExtensionArrayArithmatic:
    """Checks for ``_arith_method`` on ``ak_int64``-typed pandas arrays."""

    @pytest.mark.parametrize(
        "op, np_op",
        [
            (operator.add, operator.add),
            (operator.sub, operator.sub),
            (operator.mul, operator.mul),
        ],
    )
    def test_arith_method_with_arkouda_array_operand(self, op, np_op):
        """Array-with-array arithmetic matches the NumPy result and keeps the type."""
        left_values = [1, 2, 3]
        right_values = [10, 20, 30]
        left = pd.array(left_values, dtype="ak_int64")
        right = pd.array(right_values, dtype="ak_int64")

        result = left._arith_method(right, op)

        assert type(result) is type(left)
        reference = np_op(np.array(left_values), np.array(right_values))
        np.testing.assert_array_equal(result.to_numpy(), reference)

    @pytest.mark.parametrize(
        "op, scalar, expected",
        [
            (operator.add, 5, np.array([6, 7, 8])),
            (operator.sub, 1, np.array([0, 1, 2])),
            (operator.mul, 2, np.array([2, 4, 6])),
            (operator.add, 5.0, np.array([6.0, 7.0, 8.0])),
            (operator.sub, 1.0, np.array([0.0, 1.0, 2.0])),
            (operator.mul, 2.0, np.array([2.0, 4.0, 6.0])),
        ],
    )
    def test_arith_method_with_scalar_operand(self, op, scalar, expected):
        """Array-with-scalar arithmetic yields the expected values and dtype."""
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        result = arr._arith_method(scalar, op)

        assert type(result) is type(arr)
        assert result.to_numpy().dtype == expected.dtype
        np.testing.assert_array_equal(result.to_numpy(), expected)

    def test_arith_method_returns_notimplemented_for_unsupported_other(self):
        """A plain list operand makes ``_arith_method`` return ``NotImplemented``."""
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        # A list is neither a scalar nor an ExtensionArray, so it is rejected.
        result = arr._arith_method([1, 2, 3], operator.add)

        assert result is NotImplemented

    def test_operator_add_raises_typeerror_for_unsupported_other(self):
        """``+`` with a plain list surfaces to the caller as ``TypeError``."""
        # User-visible consequence of ``NotImplemented`` propagating out.
        arr = pd.array([1, 2, 3], dtype="ak_int64")

        with pytest.raises(TypeError):
            _ = arr + [1, 2, 3]

    def test_arith_method_unwraps_other_data_attribute(self):
        """An operand carrying ``_data`` is accepted and added elementwise."""
        left = pd.array([1, 2, 3], dtype="ak_int64")
        right = pd.array([10, 20, 30], dtype="ak_int64")

        # Precondition: the operand really exposes the attribute being unwrapped.
        assert hasattr(right, "_data")

        result = left._arith_method(right, operator.add)

        np.testing.assert_array_equal(result.to_numpy(), np.array([11, 22, 33]))