-
-
Notifications
You must be signed in to change notification settings - Fork 19.1k
ENH: Support ExtensionArray operators via a mixin #21261
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
5b0ebc7
d7596c6
7f2b0a1
ec96841
a07bb49
1d7b2b3
7bad559
dfcda3b
aaaa8fd
4bcf978
f958d7b
ef83c3a
41dc5ca
be6656b
a0f503c
700d75b
87e8f55
97bd291
8fc93e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
from .base import ExtensionArray # noqa | ||
from .base import (ExtensionArray, # noqa | ||
ExtensionArithmeticMixin, | ||
ExtensionComparisonMixin, | ||
ExtensionOpsBase) | ||
from .categorical import Categorical # noqa |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,8 +7,16 @@ | |
""" | ||
import numpy as np | ||
|
||
import operator | ||
|
||
from pandas.errors import AbstractMethodError | ||
from pandas.compat.numpy import function as nv | ||
from pandas.compat import set_function_name, PY3 | ||
import pandas.core.common as com | ||
from pandas.core.dtypes.common import ( | ||
is_extension_array_dtype, | ||
is_list_like) | ||
from pandas.core import ops | ||
|
||
_not_implemented_message = "{} does not implement {}." | ||
|
||
|
@@ -610,3 +618,121 @@ def _ndarray_values(self): | |
used for interacting with our indexers. | ||
""" | ||
return np.array(self) | ||
|
||
|
||
class ExtensionOpsBase(object): | ||
|
||
""" | ||
A base class for the mixins for different operators. | ||
Can also be used to define an individual method for a specific | ||
operator using the class method create_method() | ||
""" | ||
@classmethod | ||
Review comment: Should there be a blank line above this for PEP8? Reply: It passed the PEP8 tests, but I will add it in. |
||
def create_method(cls, op): | ||
|
||
""" | ||
A class method that returns a method that will correspond to an | ||
operator for an ExtensionArray subclass. | ||
|
||
|
||
Parameters | ||
---------- | ||
op: An operator that takes arguments op(a, b) | ||
|
||
Returns | ||
------- | ||
A method that can be bound to a method of a class | ||
|
||
Usage | ||
----- | ||
Given an ExtensionArray subclass called MyClass, use | ||
|
||
mymethod = create_method(my_operator) | ||
|
||
in the class definition of MyClass to create the operator | ||
|
||
""" | ||
op_name = ops._get_op_name(op, False) | ||
|
||
def _binop(self, other): | ||
def convert_values(parm): | ||
|
||
if isinstance(parm, ExtensionArray): | ||
ovalues = list(parm) | ||
|
||
elif is_extension_array_dtype(parm): | ||
|
||
ovalues = parm.values | ||
elif is_list_like(parm): | ||
ovalues = parm | ||
else: # Assume its an object | ||
ovalues = [parm] * len(self) | ||
return ovalues | ||
lvalues = convert_values(self) | ||
|
||
rvalues = convert_values(other) | ||
Review comment: This should probably do alignment as well? Reply: It doesn't need to, for a few reasons:
Review comment: Ah, OK, this is also tested? Reply: @jorisvandenbossche Yes, this is tested, because I am using the same tests that are used for general operators on Series, which test things that are misaligned, etc. See [link not captured]. As an aside, doing it this way uncovered the issues with [link not captured] |
||
|
||
try: | ||
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] | ||
except TypeError: | ||
msg = ("ExtensionDtype invalid operation " + | ||
"{opn} between {one} and {two}") | ||
raise TypeError(msg.format(opn=op_name, | ||
|
||
one=type(lvalues), | ||
two=type(rvalues))) | ||
|
||
res_values = com._values_from_object(res) | ||
|
||
|
||
try: | ||
res_values = self._from_sequence(res_values) | ||
except TypeError: | ||
|
||
pass | ||
|
||
return res_values | ||
|
||
name = '__{name}__'.format(name=op_name) | ||
|
||
return set_function_name(_binop, name, cls) | ||
|
||
|
||
class ExtensionArithmeticMixin(ExtensionOpsBase):
    """
    Mixin that supplies the arithmetic operators for an ExtensionArray
    class.

    Every operator is produced by ``ExtensionOpsBase.create_method``, so
    it is assumed that the underlying objects already define the
    operators themselves.

    Usage
    -----
    If you have defined a subclass MyClass(ExtensionArray), then
    use MyClass(ExtensionArray, ExtensionArithmeticMixin) to
    get the arithmetic operators
    """

    # Short alias for the factory; it is deleted at the end of the class
    # body so that it never becomes an attribute of the mixin.
    _make = ExtensionOpsBase.create_method

    # Forward operator first, reflected operator second.
    __add__, __radd__ = _make(operator.add), _make(ops.radd)
    __sub__, __rsub__ = _make(operator.sub), _make(ops.rsub)
    __mul__, __rmul__ = _make(operator.mul), _make(ops.rmul)
    __pow__, __rpow__ = _make(operator.pow), _make(ops.rpow)
    __mod__, __rmod__ = _make(operator.mod), _make(ops.rmod)
    __floordiv__, __rfloordiv__ = (_make(operator.floordiv),
                                   _make(ops.rfloordiv))
    __truediv__, __rtruediv__ = (_make(operator.truediv),
                                 _make(ops.rtruediv))
    if not PY3:
        # Classic division is only defined when not running on Python 3.
        __div__, __rdiv__ = _make(operator.div), _make(ops.rdiv)

    __divmod__ = _make(divmod)

    del _make
|
||
|
||
|
||
class ExtensionComparisonMixin(ExtensionOpsBase):
    """
    Mixin that supplies the comparison operators for an ExtensionArray
    class.

    Every operator is produced by ``ExtensionOpsBase.create_method``, so
    it is assumed that the underlying objects already define the
    operators themselves.

    Usage
    -----
    If you have defined a subclass MyClass(ExtensionArray), then
    use MyClass(ExtensionArray, ExtensionComparisonMixin) to
    get the comparison operators
    """
    # A generator expression keeps the loop variable out of the class
    # namespace; the methods are created in the order the names appear.
    (__eq__, __ne__, __lt__,
     __gt__, __le__, __ge__) = (
        ExtensionOpsBase.create_method(compare)
        for compare in (operator.eq, operator.ne, operator.lt,
                        operator.gt, operator.le, operator.ge))
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2972,16 +2972,20 @@ def get_value(self, series, key): | |
# use this, e.g. DatetimeIndex | ||
s = getattr(series, '_values', None) | ||
if isinstance(s, (ExtensionArray, Index)) and is_scalar(key): | ||
# GH 20825 | ||
# GH 20882, 21257 | ||
# Unify Index and ExtensionArray treatment | ||
# First try to convert the key to a location | ||
# If that fails, see if key is an integer, and | ||
# If that fails, raise a KeyError if an integer | ||
# index, otherwise, see if key is an integer, and | ||
# try that | ||
try: | ||
iloc = self.get_loc(key) | ||
return s[iloc] | ||
except KeyError: | ||
if is_integer(key): | ||
if (len(self) > 0 and | ||
self.inferred_type in ['integer', 'boolean']): | ||
|
||
raise | ||
|
||
elif is_integer(key): | ||
return s[key] | ||
|
||
s = com._values_from_object(series) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
is_bool_dtype, | ||
is_list_like, | ||
is_scalar, | ||
is_extension_array_dtype, | ||
_ensure_object) | ||
from pandas.core.dtypes.cast import ( | ||
maybe_upcast_putmask, find_common_type, | ||
|
@@ -990,6 +991,26 @@ def _construct_divmod_result(left, result, index, name, dtype): | |
) | ||
|
||
|
||
def dispatch_to_extension_op(left, right, op_name):
    """
    Apply the operator named by op_name to the values of left, which is
    assumed to be a Series backed by an ExtensionArray.

    Parameters
    ----------
    left : object providing ``values``, ``index`` and ``_constructor``
    right : object
        Passed unchanged to the method looked up on ``left.values``
    op_name : str
        Name of the method to look up on ``left.values``

    Returns
    -------
    The result of ``left._constructor`` called with the operator result,
    ``left.index`` and the name given by ``get_op_result_name``

    Raises
    ------
    TypeError
        If ``left.values`` has no attribute ``op_name`` or the method
        returns ``NotImplemented``
    """
    method = getattr(left.values, op_name, None)

    # A missing method is handled exactly like one that declines the
    # operation by returning NotImplemented.
    res_values = NotImplemented if method is None else method(right)

    if res_values is NotImplemented:
        template = ("ExtensionArray invalid operation {opn} "
                    "between {one} and {two}")
        raise TypeError(template.format(opn=op_name,
                                        one=type(left.values),
                                        two=type(right)))

    result_name = get_op_result_name(left, right)
    return left._constructor(res_values, index=left.index,
                             name=result_name)
|
||
|
||
def _arith_method_SERIES(cls, op, special): | ||
""" | ||
Wrapper function for Series arithmetic operations, to avoid | ||
|
@@ -1058,6 +1079,9 @@ def wrapper(left, right): | |
raise TypeError("{typ} cannot perform the operation " | ||
"{op}".format(typ=type(left).__name__, op=str_rep)) | ||
|
||
elif is_extension_array_dtype(left): | ||
return dispatch_to_extension_op(left, right, op_name) | ||
|
||
lvalues = left.values | ||
rvalues = right | ||
if isinstance(rvalues, ABCSeries): | ||
|
@@ -1208,6 +1232,9 @@ def wrapper(self, other, axis=None): | |
return self._constructor(res_values, index=self.index, | ||
name=res_name) | ||
|
||
elif is_extension_array_dtype(self): | ||
return dispatch_to_extension_op(self, other, op_name) | ||
|
||
elif isinstance(other, ABCSeries): | ||
# By this point we have checked that self._indexed_same(other) | ||
res_values = na_op(self.values, other.values) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2196,23 +2196,22 @@ def _binop(self, other, func, level=None, fill_value=None): | |
result.name = None | ||
return result | ||
|
||
def combine(self, other, func, fill_value=np.nan): | ||
def combine(self, other, func, fill_value=None): | ||
""" | ||
Perform elementwise binary operation on two Series using given function | ||
with optional fill value when an index is missing from one Series or | ||
the other | ||
|
||
Parameters | ||
---------- | ||
other : Series or scalar value | ||
func : function | ||
Function that takes two scalars as inputs and return a scalar | ||
fill_value : scalar value | ||
|
||
The default specifies to use the appropriate NaN value for | ||
the underlying dtype of the Series | ||
Returns | ||
------- | ||
result : Series | ||
|
||
Examples | ||
-------- | ||
>>> s1 = Series([1, 2]) | ||
|
@@ -2221,26 +2220,36 @@ def combine(self, other, func, fill_value=np.nan): | |
0 0 | ||
1 2 | ||
dtype: int64 | ||
|
||
See Also | ||
-------- | ||
Series.combine_first : Combine Series values, choosing the calling | ||
Series's values first | ||
""" | ||
self_is_ext = is_extension_array_dtype(self.values) | ||
|
||
if fill_value is None: | ||
fill_value = na_value_for_dtype(self.dtype, False) | ||
|
||
if isinstance(other, Series): | ||
new_index = self.index.union(other.index) | ||
new_name = ops.get_op_result_name(self, other) | ||
new_values = np.empty(len(new_index), dtype=self.dtype) | ||
for i, idx in enumerate(new_index): | ||
new_values = [] | ||
for idx in new_index: | ||
lv = self.get(idx, fill_value) | ||
rv = other.get(idx, fill_value) | ||
with np.errstate(all='ignore'): | ||
new_values[i] = func(lv, rv) | ||
new_values.append(func(lv, rv)) | ||
else: | ||
new_index = self.index | ||
with np.errstate(all='ignore'): | ||
new_values = func(self._values, other) | ||
new_values = [func(lv, other) for lv in self._values] | ||
new_name = self.name | ||
|
||
if self_is_ext and not is_categorical_dtype(self.values): | ||
|
||
try: | ||
new_values = self._values._from_sequence(new_values) | ||
except TypeError: | ||
Review comment: Why are you catching a TypeError? Reply: This is also for the other PR, but: because the result might not necessarily be an ExtensionArray. This is of course a bit of an unclear area of |
||
pass | ||
|
||
return self._constructor(new_values, index=new_index, name=new_name) | ||
|
||
def combine_first(self, other): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,9 @@ | |
import numpy as np | ||
|
||
import pandas as pd | ||
from pandas.core.arrays import ExtensionArray | ||
from pandas.core.arrays import (ExtensionArray, | ||
ExtensionArithmeticMixin, | ||
ExtensionComparisonMixin) | ||
from pandas.core.dtypes.base import ExtensionDtype | ||
|
||
|
||
|
@@ -24,11 +26,14 @@ def construct_from_string(cls, string): | |
"'{}'".format(cls, string)) | ||
|
||
|
||
class DecimalArray(ExtensionArray): | ||
class DecimalArray(ExtensionArray, ExtensionArithmeticMixin, | ||
ExtensionComparisonMixin): | ||
dtype = DecimalDtype() | ||
|
||
def __init__(self, values): | ||
assert all(isinstance(v, decimal.Decimal) for v in values) | ||
for val in values: | ||
if not isinstance(val, self.dtype.type): | ||
raise TypeError | ||
|
||
values = np.asarray(values, dtype=object) | ||
|
||
self._data = values | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,9 @@ | |
|
||
from pandas.tests.extension import base | ||
|
||
from pandas.tests.series.test_operators import TestSeriesOperators | ||
from pandas.util._decorators import cache_readonly | ||
|
||
from .array import DecimalDtype, DecimalArray, make_data | ||
|
||
|
||
|
@@ -183,3 +186,36 @@ def test_dataframe_constructor_with_different_dtype_raises(): | |
xpr = "Cannot coerce extension array to dtype 'int64'. " | ||
with tm.assert_raises_regex(ValueError, xpr): | ||
pd.DataFrame({"A": arr}, dtype='int64') | ||
|
||
|
||
_ts = pd.Series(DecimalArray(make_data())) | ||
|
||
|
||
class TestOperator(BaseDecimal, TestSeriesOperators): | ||
@cache_readonly | ||
|
||
def ts(self): | ||
ts = _ts.copy() | ||
ts.name = 'ts' | ||
return ts | ||
|
||
def test_operators(self): | ||
def absfunc(v): | ||
if isinstance(v, pd.Series): | ||
vals = v.values | ||
return pd.Series(vals._from_sequence([abs(i) for i in vals])) | ||
else: | ||
return abs(v) | ||
context = decimal.getcontext() | ||
Review comment: What the heck is all this? Reply: I am reusing the tests that are in the class [name not captured]. There is a related issue with respect to |
||
divbyzerotrap = context.traps[decimal.DivisionByZero] | ||
invalidoptrap = context.traps[decimal.InvalidOperation] | ||
context.traps[decimal.DivisionByZero] = 0 | ||
context.traps[decimal.InvalidOperation] = 0 | ||
super(TestOperator, self).test_operators(absfunc) | ||
context.traps[decimal.DivisionByZero] = divbyzerotrap | ||
context.traps[decimal.InvalidOperation] = invalidoptrap | ||
|
||
def test_operators_corner(self): | ||
pytest.skip("Cannot add empty Series of float64 to DecimalArray") | ||
|
||
def test_divmod(self): | ||
pytest.skip("divmod not appropriate for Decimal type") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this a separate sub-section. Please expand it a bit; I know what you mean, but I doubt the average reader does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the note here still accurate or has the factory been changed to just mixins?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've written a section for whatsnew