diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 379c07de..ac86ad97 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -1,8 +1,10 @@ Release Notes ============= -.. Upcoming Version -.. ---------------- +Upcoming Version +---------------- + +* Added support for arithmetic operations with custom classes. Version 0.5.0 -------------- diff --git a/linopy/expressions.py b/linopy/expressions.py index 4473230f..4c5598af 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -488,11 +488,14 @@ def __add__(self, other: SideLike) -> LinearExpression: Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ - if np.isscalar(other): - return self.assign(const=self.const + other) + try: + if np.isscalar(other): + return self.assign(const=self.const + other) - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, other], cls=self.__class__) + other = as_expression(other, model=self.model, dims=self.coord_dims) + return merge([self, other], cls=self.__class__) + except TypeError: + return NotImplemented def __radd__(self, other: int) -> LinearExpression | NotImplementedType: # This is needed for using python's sum function @@ -505,11 +508,14 @@ def __sub__(self, other: SideLike) -> LinearExpression: Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ - if np.isscalar(other): - return self.assign_multiindex_safe(const=self.const - other) + try: + if np.isscalar(other): + return self.assign_multiindex_safe(const=self.const - other) - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, -other], cls=self.__class__) + other = as_expression(other, model=self.model, dims=self.coord_dims) + return merge([self, -other], cls=self.__class__) + except TypeError: + return NotImplemented def __neg__(self) -> LinearExpression | QuadraticExpression: """ @@ -524,19 +530,22 @@ def __mul__( """ Multiply the expr by a factor. """ - if isinstance(other, QuadraticExpression): - raise TypeError( - "unsupported operand type(s) for *: " - f"{type(self)} and {type(other)}. " - "Higher order non-linear expressions are not yet supported." - ) - elif isinstance(other, (variables.Variable, variables.ScalarVariable)): - other = other.to_linexpr() + try: + if isinstance(other, QuadraticExpression): + raise TypeError( + "unsupported operand type(s) for *: " + f"{type(self)} and {type(other)}. " + "Higher order non-linear expressions are not yet supported." + ) + elif isinstance(other, (variables.Variable, variables.ScalarVariable)): + other = other.to_linexpr() - if isinstance(other, (LinearExpression, ScalarLinearExpression)): - return self._multiply_by_linear_expression(other) - else: - return self._multiply_by_constant(other) + if isinstance(other, (LinearExpression, ScalarLinearExpression)): + return self._multiply_by_linear_expression(other) + else: + return self._multiply_by_constant(other) + except TypeError: + return NotImplemented def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression @@ -599,15 +608,18 @@ def __matmul__( def __div__( self, other: Variable | ConstantLike ) -> LinearExpression | QuadraticExpression: - if isinstance( - other, (LinearExpression, variables.Variable, variables.ScalarVariable) - ): - raise TypeError( - "unsupported operand type(s) for /: " - f"{type(self)} and {type(other)}" - "Non-linear expressions are not yet supported." - ) - return self.__mul__(1 / other) + try: + if isinstance( + other, (LinearExpression, variables.Variable, variables.ScalarVariable) + ): + raise TypeError( + "unsupported operand type(s) for /: " + f"{type(self)} and {type(other)}" + "Non-linear expressions are not yet supported." + ) + return self.__mul__(1 / other) + except TypeError: + return NotImplemented def __truediv__( self, other: Variable | ConstantLike @@ -1557,13 +1569,17 @@ def __add__( Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ - if np.isscalar(other): - return self.assign(const=self.const + other) + try: + if np.isscalar(other): + return self.assign(const=self.const + other) - other = as_expression(other, model=self.model, dims=self.coord_dims) - if type(other) is LinearExpression: - other = other.to_quadexpr() - return merge([self, other], cls=self.__class__) # type: ignore + other = as_expression(other, model=self.model, dims=self.coord_dims) + + if type(other) is LinearExpression: + other = other.to_quadexpr() + return merge([self, other], cls=self.__class__) # type: ignore + except TypeError: + return NotImplemented def __radd__( self, other: LinearExpression | int @@ -1586,13 +1602,16 @@ def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression: Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ - if np.isscalar(other): - return self.assign(const=self.const - other) - - other = as_expression(other, model=self.model, dims=self.coord_dims) - if type(other) is LinearExpression: - other = other.to_quadexpr() - return merge([self, -other], cls=self.__class__) # type: ignore + try: + if np.isscalar(other): + return self.assign(const=self.const - other) + + other = as_expression(other, model=self.model, dims=self.coord_dims) + if type(other) is LinearExpression: + other = other.to_quadexpr() + return merge([self, -other], cls=self.__class__) # type: ignore + except TypeError: + return NotImplemented def __rsub__(self, other: LinearExpression) -> QuadraticExpression: """ diff --git a/linopy/variables.py b/linopy/variables.py index 2b7263e8..d2dd3cd6 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -388,25 +388,33 @@ def __mul__( """ Multiply variables with a coefficient. """ - if isinstance(other, (expressions.LinearExpression, Variable, ScalarVariable)): - return self.to_linexpr() * other - else: + try: + if isinstance( + other, (expressions.LinearExpression, Variable, ScalarVariable) + ): + return self.to_linexpr() * other + return self.to_linexpr(other) + except TypeError: + return NotImplemented def __pow__(self, other: int) -> QuadraticExpression: """ Power of the variables with a coefficient. The only coefficient allowed is 2. """ - if not other == 2: - raise ValueError("Power must be 2.") - expr = self.to_linexpr() - return expr._multiply_by_linear_expression(expr) + if isinstance(other, int) and other == 2: + expr = self.to_linexpr() + return expr._multiply_by_linear_expression(expr) + return NotImplemented def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression: """ Right-multiply variables with a coefficient. """ - return self.to_linexpr(other) + try: + return self.to_linexpr(other) + except TypeError: + return NotImplemented def __matmul__( self, other: LinearExpression | ndarray | Variable @@ -436,7 +444,10 @@ def __truediv__( """ True divide variables with a coefficient. """ - return self.__div__(coefficient) + try: + return self.__div__(coefficient) + except TypeError: + return NotImplemented def __add__( self, other: int | QuadraticExpression | LinearExpression | Variable @@ -444,7 +455,10 @@ def __add__( """ Add variables to linear expressions or other variables. """ - return self.to_linexpr() + other + try: + return self.to_linexpr() + other + except TypeError: + return NotImplemented def __radd__(self, other: int) -> Variable | NotImplementedType: # This is needed for using python's sum function @@ -456,7 +470,10 @@ def __sub__( """ Subtract linear expressions or other variables from the variables. """ - return self.to_linexpr() - other + try: + return self.to_linexpr() - other + except TypeError: + return NotImplemented def __le__(self, other: SideLike) -> Constraint: return self.to_linexpr().__le__(other) diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py new file mode 100644 index 00000000..0b5829cf --- /dev/null +++ b/test/test_compatible_arithmetrics.py @@ -0,0 +1,139 @@ +from typing import Any + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from linopy import LESS_EQUAL, Model, Variable +from linopy.testing import assert_linequal + + +class SomeOtherDatatype: + """ + A class that is not a subclass of xarray.DataArray, but stores data in a compatible way. + It defines all necessary arithmetrics AND __array_ufunc__ to ensure that operations are + performed on the active_data. + """ + + def __init__(self, data: xr.DataArray) -> None: + self.data1 = data + self.data2 = data.copy() + self.active = 1 + + def activate(self, active: int) -> None: + self.active = active + + @property + def active_data(self) -> xr.DataArray: + return self.data1 if self.active == 1 else self.data2 + + def __add__(self, other: Any) -> xr.DataArray: + return self.active_data + other + + def __sub__(self, other: Any) -> xr.DataArray: + return self.active_data - other + + def __mul__(self, other: Any) -> xr.DataArray: + return self.active_data * other + + def __truediv__(self, other: Any) -> xr.DataArray: + return self.active_data / other + + def __radd__(self, other: Any) -> Any: + return other + self.active_data + + def __rsub__(self, other: Any) -> Any: + return other - self.active_data + + def __rmul__(self, other: Any) -> Any: + return other * self.active_data + + def __rtruediv__(self, other: Any) -> Any: + return other / self.active_data + + def __neg__(self) -> xr.DataArray: + return -self.active_data + + def __pos__(self) -> xr.DataArray: + return +self.active_data + + def __abs__(self) -> xr.DataArray: + return abs(self.active_data) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # type: ignore + # Ensure we always use the active_data when interacting with numpy/xarray operations + new_inputs = [ + inp.active_data if isinstance(inp, SomeOtherDatatype) else inp + for inp in inputs + ] + return getattr(ufunc, method)(*new_inputs, **kwargs) + + +@pytest.fixture( + params=[ + (pd.RangeIndex(10, name="first"),), + ( + pd.Index(range(5), name="first"), + pd.Index(range(3), name="second"), + pd.Index(range(2), name="third"), + ), + ], + ids=["single_dim", "multi_dim"], +) +def m(request) -> Model: # type: ignore + m = Model() + x = m.add_variables(coords=request.param, name="x") + m.add_variables(0, 10, name="z") + m.add_constraints(x, LESS_EQUAL, 0, name="c") + return m + + +def test_arithmetric_operations_variable(m: Model) -> None: + x: Variable = m.variables["x"] + rng = np.random.default_rng() + data = xr.DataArray(rng.random(x.shape), coords=x.coords) + other_datatype = SomeOtherDatatype(data.copy()) + assert_linequal(x + data, x + other_datatype) # type: ignore + assert_linequal(x - data, x - other_datatype) # type: ignore + assert_linequal(x * data, x * other_datatype) # type: ignore + assert_linequal(x / data, x / other_datatype) # type: ignore + assert_linequal(data * x, other_datatype * x) # type: ignore + assert x.__add__(object()) is NotImplemented # type: ignore + assert x.__sub__(object()) is NotImplemented # type: ignore + assert x.__mul__(object()) is NotImplemented # type: ignore + assert x.__truediv__(object()) is NotImplemented # type: ignore + assert x.__pow__(object()) is NotImplemented # type: ignore + assert x.__pow__(3) is NotImplemented + + +def test_arithmetric_operations_expr(m: Model) -> None: + x = m.variables["x"] + expr = x + 3 + rng = np.random.default_rng() + data = xr.DataArray(rng.random(x.shape), coords=x.coords) + other_datatype = SomeOtherDatatype(data.copy()) + assert_linequal(expr + data, expr + other_datatype) + assert_linequal(expr - data, expr - other_datatype) + assert_linequal(expr * data, expr * other_datatype) + assert_linequal(expr / data, expr / other_datatype) + assert expr.__add__(object()) is NotImplemented + assert expr.__sub__(object()) is NotImplemented + assert expr.__mul__(object()) is NotImplemented + assert expr.__truediv__(object()) is NotImplemented + + +def test_arithmetric_operations_con(m: Model) -> None: + c = m.constraints["c"] + x = m.variables["x"] + rng = np.random.default_rng() + data = xr.DataArray(rng.random(x.shape), coords=x.coords) + other_datatype = SomeOtherDatatype(data.copy()) + assert_linequal(c.lhs + data, c.lhs + other_datatype) + assert_linequal(c.lhs - data, c.lhs - other_datatype) + assert_linequal(c.lhs * data, c.lhs * other_datatype) + assert_linequal(c.lhs / data, c.lhs / other_datatype) + assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore + assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore + assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore + assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore