Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Release Notes
.. Upcoming Version
.. ----------------

.. * Added support for arithmetic operations with custom classes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.. means this is commented out, please remove it for your note and the "upcoming version" header


Version 0.5.0
--------------

Expand Down
26 changes: 20 additions & 6 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,10 @@
if np.isscalar(other):
return self.assign(const=self.const + other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
try:
other = as_expression(other, model=self.model, dims=self.coord_dims)
except TypeError:
return NotImplemented
return merge([self, other], cls=self.__class__)

def __radd__(self, other: int) -> LinearExpression | NotImplementedType:
Expand All @@ -508,7 +511,10 @@
if np.isscalar(other):
return self.assign_multiindex_safe(const=self.const - other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
try:
other = as_expression(other, model=self.model, dims=self.coord_dims)
except TypeError:
return NotImplemented
return merge([self, -other], cls=self.__class__)

def __neg__(self) -> LinearExpression | QuadraticExpression:
Expand Down Expand Up @@ -536,7 +542,10 @@
if isinstance(other, (LinearExpression, ScalarLinearExpression)):
return self._multiply_by_linear_expression(other)
else:
return self._multiply_by_constant(other)
try:
return self._multiply_by_constant(other)
except TypeError:
return NotImplemented

def _multiply_by_linear_expression(
self, other: LinearExpression | ScalarLinearExpression
Expand Down Expand Up @@ -1560,7 +1569,10 @@
if np.isscalar(other):
return self.assign(const=self.const + other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
try:
other = as_expression(other, model=self.model, dims=self.coord_dims)
except TypeError:
return NotImplemented

Check warning on line 1575 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L1574-L1575

Added lines #L1574 - L1575 were not covered by tests
if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, other], cls=self.__class__) # type: ignore
Expand Down Expand Up @@ -1588,8 +1600,10 @@
"""
if np.isscalar(other):
return self.assign(const=self.const - other)

other = as_expression(other, model=self.model, dims=self.coord_dims)
try:
other = as_expression(other, model=self.model, dims=self.coord_dims)
except TypeError:
return NotImplemented

Check warning on line 1606 in linopy/expressions.py

View check run for this annotation

Codecov / codecov/patch

linopy/expressions.py#L1605-L1606

Added lines #L1605 - L1606 were not covered by tests
if type(other) is LinearExpression:
other = other.to_quadexpr()
return merge([self, -other], cls=self.__class__) # type: ignore
Expand Down
29 changes: 23 additions & 6 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@
linopy.LinearExpression
Linear expression with the variables and coefficients.
"""
coefficient = as_dataarray(coefficient, coords=self.coords, dims=self.dims)
try:
coefficient = as_dataarray(coefficient, coords=self.coords, dims=self.dims)
except TypeError:
return NotImplemented
ds = Dataset({"coeffs": coefficient, "vars": self.labels}).expand_dims(
TERM_DIM, -1
)
Expand Down Expand Up @@ -390,8 +393,10 @@
"""
if isinstance(other, (expressions.LinearExpression, Variable, ScalarVariable)):
return self.to_linexpr() * other
else:
try:
return self.to_linexpr(other)
except TypeError:
return NotImplemented

Check warning on line 399 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L398-L399

Added lines #L398 - L399 were not covered by tests

def __pow__(self, other: int) -> QuadraticExpression:
"""
Expand All @@ -406,7 +411,10 @@
"""
Right-multiply variables with a coefficient.
"""
return self.to_linexpr(other)
try:
return self.to_linexpr(other)
except TypeError:
return NotImplemented

Check warning on line 417 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L416-L417

Added lines #L416 - L417 were not covered by tests

def __matmul__(
self, other: LinearExpression | ndarray | Variable
Expand Down Expand Up @@ -436,15 +444,21 @@
"""
True divide variables with a coefficient.
"""
return self.__div__(coefficient)
try:
return self.__div__(coefficient)
except TypeError:
return NotImplemented

def __add__(
self, other: int | QuadraticExpression | LinearExpression | Variable
) -> QuadraticExpression | LinearExpression:
"""
Add variables to linear expressions or other variables.
"""
return self.to_linexpr() + other
try:
return self.to_linexpr() + other
except TypeError:
return NotImplemented

Check warning on line 461 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L460-L461

Added lines #L460 - L461 were not covered by tests

def __radd__(self, other: int) -> Variable | NotImplementedType:
# This is needed for using python's sum function
Expand All @@ -456,7 +470,10 @@
"""
Subtract linear expressions or other variables from the variables.
"""
return self.to_linexpr() - other
try:
return self.to_linexpr() - other
except TypeError:
return NotImplemented

Check warning on line 476 in linopy/variables.py

View check run for this annotation

Codecov / codecov/patch

linopy/variables.py#L475-L476

Added lines #L475 - L476 were not covered by tests

def __le__(self, other: SideLike) -> Constraint:
return self.to_linexpr().__le__(other)
Expand Down
114 changes: 114 additions & 0 deletions test/test_compatible_arithmetrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr

from linopy import Model
from linopy.testing import assert_linequal


class SomeOtherDatatype:
"""
A class that is not a subclass of xarray.DataArray, but stores data in a compatible way.
It defines all necessary arithmetrics AND __array_ufunc__ to ensure that operations are
performed on the active_data.
"""

def __init__(self, data: xr.DataArray) -> None:
self.data1 = data
self.data2 = data.copy()
self.active = 1

def activate(self, active: int) -> None:
self.active = active

@property
def active_data(self) -> xr.DataArray:
return self.data1 if self.active == 1 else self.data2

def __add__(self, other):
return self.active_data + other

def __sub__(self, other):
return self.active_data - other

def __mul__(self, other):
return self.active_data * other

def __truediv__(self, other):
return self.active_data / other

def __radd__(self, other):
return other + self.active_data

def __rsub__(self, other):
return other - self.active_data

def __rmul__(self, other):
return other * self.active_data

def __rtruediv__(self, other):
return other / self.active_data

def __neg__(self):
return -self.active_data

def __pos__(self):
return +self.active_data

def __abs__(self):
return abs(self.active_data)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# Ensure we always use the active_data when interacting with numpy/xarray operations
new_inputs = [
inp.active_data if isinstance(inp, SomeOtherDatatype) else inp
for inp in inputs
]
return getattr(ufunc, method)(*new_inputs, **kwargs)


@pytest.fixture(
params=[
(pd.RangeIndex(10, name="first"),),
(
pd.Index(range(5), name="first"),
pd.Index(range(3), name="second"),
pd.Index(range(2), name="third"),
),
],
ids=["single_dim", "multi_dim"],
)
def m(request) -> Model:
m = Model()
m.add_variables(coords=request.param, name="x")
m.add_variables(0, 10, name="z")
m.add_constraints(m.variables["x"] >= 0, name="c")
return m


def test_arithmetric_operations_variable(m: Model) -> None:
x = m.variables["x"]
rng = np.random.default_rng()
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
other_datatype = SomeOtherDatatype(data.copy())
assert_linequal(x + data, x + other_datatype)
assert_linequal(x - data, x - other_datatype)
assert_linequal(x * data, x * other_datatype)
assert_linequal(x / data, x / other_datatype)


def test_arithmetric_operations_con(m: Model) -> None:
c = m.constraints["c"]
x = m.variables["x"]
rng = np.random.default_rng()
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
other_datatype = SomeOtherDatatype(data.copy())
assert_linequal(c.lhs + data, c.lhs + other_datatype)
assert_linequal(c.lhs - data, c.lhs - other_datatype)
assert_linequal(c.lhs * data, c.lhs * other_datatype)
assert_linequal(c.lhs / data, c.lhs / other_datatype)
assert_linequal(c.rhs + data, c.rhs + other_datatype)
assert_linequal(c.rhs - data, c.rhs - other_datatype)
assert_linequal(c.rhs * data, c.rhs * other_datatype)
assert_linequal(c.rhs / data, c.rhs / other_datatype)
Loading