Skip to content

Commit 0fa0c74

Browse files
authored
Maximize compatability with Datatypes by returning NotImplemented if __add__, __mul__ ... fail (#417)
1 parent 99b2bb7 commit 0fa0c74

File tree

4 files changed

+232
-55
lines changed

4 files changed

+232
-55
lines changed

doc/release_notes.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
Release Notes
22
=============
33

4-
.. Upcoming Version
5-
.. ----------------
4+
Upcoming Version
5+
----------------
6+
7+
* Added support for arithmetic operations with custom classes.
68

79
Version 0.5.0
810
--------------

linopy/expressions.py

Lines changed: 61 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -488,11 +488,14 @@ def __add__(self, other: SideLike) -> LinearExpression:
488488
Note: If other is a numpy array or pandas object without axes names,
489489
dimension names of self will be filled in other
490490
"""
491-
if np.isscalar(other):
492-
return self.assign(const=self.const + other)
491+
try:
492+
if np.isscalar(other):
493+
return self.assign(const=self.const + other)
493494

494-
other = as_expression(other, model=self.model, dims=self.coord_dims)
495-
return merge([self, other], cls=self.__class__)
495+
other = as_expression(other, model=self.model, dims=self.coord_dims)
496+
return merge([self, other], cls=self.__class__)
497+
except TypeError:
498+
return NotImplemented
496499

497500
def __radd__(self, other: int) -> LinearExpression | NotImplementedType:
498501
# This is needed for using python's sum function
@@ -505,11 +508,14 @@ def __sub__(self, other: SideLike) -> LinearExpression:
505508
Note: If other is a numpy array or pandas object without axes names,
506509
dimension names of self will be filled in other
507510
"""
508-
if np.isscalar(other):
509-
return self.assign_multiindex_safe(const=self.const - other)
511+
try:
512+
if np.isscalar(other):
513+
return self.assign_multiindex_safe(const=self.const - other)
510514

511-
other = as_expression(other, model=self.model, dims=self.coord_dims)
512-
return merge([self, -other], cls=self.__class__)
515+
other = as_expression(other, model=self.model, dims=self.coord_dims)
516+
return merge([self, -other], cls=self.__class__)
517+
except TypeError:
518+
return NotImplemented
513519

514520
def __neg__(self) -> LinearExpression | QuadraticExpression:
515521
"""
@@ -524,19 +530,22 @@ def __mul__(
524530
"""
525531
Multiply the expr by a factor.
526532
"""
527-
if isinstance(other, QuadraticExpression):
528-
raise TypeError(
529-
"unsupported operand type(s) for *: "
530-
f"{type(self)} and {type(other)}. "
531-
"Higher order non-linear expressions are not yet supported."
532-
)
533-
elif isinstance(other, (variables.Variable, variables.ScalarVariable)):
534-
other = other.to_linexpr()
533+
try:
534+
if isinstance(other, QuadraticExpression):
535+
raise TypeError(
536+
"unsupported operand type(s) for *: "
537+
f"{type(self)} and {type(other)}. "
538+
"Higher order non-linear expressions are not yet supported."
539+
)
540+
elif isinstance(other, (variables.Variable, variables.ScalarVariable)):
541+
other = other.to_linexpr()
535542

536-
if isinstance(other, (LinearExpression, ScalarLinearExpression)):
537-
return self._multiply_by_linear_expression(other)
538-
else:
539-
return self._multiply_by_constant(other)
543+
if isinstance(other, (LinearExpression, ScalarLinearExpression)):
544+
return self._multiply_by_linear_expression(other)
545+
else:
546+
return self._multiply_by_constant(other)
547+
except TypeError:
548+
return NotImplemented
540549

541550
def _multiply_by_linear_expression(
542551
self, other: LinearExpression | ScalarLinearExpression
@@ -599,15 +608,18 @@ def __matmul__(
599608
def __div__(
600609
self, other: Variable | ConstantLike
601610
) -> LinearExpression | QuadraticExpression:
602-
if isinstance(
603-
other, (LinearExpression, variables.Variable, variables.ScalarVariable)
604-
):
605-
raise TypeError(
606-
"unsupported operand type(s) for /: "
607-
f"{type(self)} and {type(other)}"
608-
"Non-linear expressions are not yet supported."
609-
)
610-
return self.__mul__(1 / other)
611+
try:
612+
if isinstance(
613+
other, (LinearExpression, variables.Variable, variables.ScalarVariable)
614+
):
615+
raise TypeError(
616+
"unsupported operand type(s) for /: "
617+
f"{type(self)} and {type(other)}"
618+
"Non-linear expressions are not yet supported."
619+
)
620+
return self.__mul__(1 / other)
621+
except TypeError:
622+
return NotImplemented
611623

612624
def __truediv__(
613625
self, other: Variable | ConstantLike
@@ -1557,13 +1569,17 @@ def __add__(
15571569
Note: If other is a numpy array or pandas object without axes names,
15581570
dimension names of self will be filled in other
15591571
"""
1560-
if np.isscalar(other):
1561-
return self.assign(const=self.const + other)
1572+
try:
1573+
if np.isscalar(other):
1574+
return self.assign(const=self.const + other)
15621575

1563-
other = as_expression(other, model=self.model, dims=self.coord_dims)
1564-
if type(other) is LinearExpression:
1565-
other = other.to_quadexpr()
1566-
return merge([self, other], cls=self.__class__) # type: ignore
1576+
other = as_expression(other, model=self.model, dims=self.coord_dims)
1577+
1578+
if type(other) is LinearExpression:
1579+
other = other.to_quadexpr()
1580+
return merge([self, other], cls=self.__class__) # type: ignore
1581+
except TypeError:
1582+
return NotImplemented
15671583

15681584
def __radd__(
15691585
self, other: LinearExpression | int
@@ -1586,13 +1602,16 @@ def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression:
15861602
Note: If other is a numpy array or pandas object without axes names,
15871603
dimension names of self will be filled in other
15881604
"""
1589-
if np.isscalar(other):
1590-
return self.assign(const=self.const - other)
1591-
1592-
other = as_expression(other, model=self.model, dims=self.coord_dims)
1593-
if type(other) is LinearExpression:
1594-
other = other.to_quadexpr()
1595-
return merge([self, -other], cls=self.__class__) # type: ignore
1605+
try:
1606+
if np.isscalar(other):
1607+
return self.assign(const=self.const - other)
1608+
1609+
other = as_expression(other, model=self.model, dims=self.coord_dims)
1610+
if type(other) is LinearExpression:
1611+
other = other.to_quadexpr()
1612+
return merge([self, -other], cls=self.__class__) # type: ignore
1613+
except TypeError:
1614+
return NotImplemented
15961615

15971616
def __rsub__(self, other: LinearExpression) -> QuadraticExpression:
15981617
"""

linopy/variables.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -388,25 +388,33 @@ def __mul__(
388388
"""
389389
Multiply variables with a coefficient.
390390
"""
391-
if isinstance(other, (expressions.LinearExpression, Variable, ScalarVariable)):
392-
return self.to_linexpr() * other
393-
else:
391+
try:
392+
if isinstance(
393+
other, (expressions.LinearExpression, Variable, ScalarVariable)
394+
):
395+
return self.to_linexpr() * other
396+
394397
return self.to_linexpr(other)
398+
except TypeError:
399+
return NotImplemented
395400

396401
def __pow__(self, other: int) -> QuadraticExpression:
397402
"""
398403
Power of the variables with a coefficient. The only coefficient allowed is 2.
399404
"""
400-
if not other == 2:
401-
raise ValueError("Power must be 2.")
402-
expr = self.to_linexpr()
403-
return expr._multiply_by_linear_expression(expr)
405+
if isinstance(other, int) and other == 2:
406+
expr = self.to_linexpr()
407+
return expr._multiply_by_linear_expression(expr)
408+
return NotImplemented
404409

405410
def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression:
406411
"""
407412
Right-multiply variables with a coefficient.
408413
"""
409-
return self.to_linexpr(other)
414+
try:
415+
return self.to_linexpr(other)
416+
except TypeError:
417+
return NotImplemented
410418

411419
def __matmul__(
412420
self, other: LinearExpression | ndarray | Variable
@@ -436,15 +444,21 @@ def __truediv__(
436444
"""
437445
True divide variables with a coefficient.
438446
"""
439-
return self.__div__(coefficient)
447+
try:
448+
return self.__div__(coefficient)
449+
except TypeError:
450+
return NotImplemented
440451

441452
def __add__(
442453
self, other: int | QuadraticExpression | LinearExpression | Variable
443454
) -> QuadraticExpression | LinearExpression:
444455
"""
445456
Add variables to linear expressions or other variables.
446457
"""
447-
return self.to_linexpr() + other
458+
try:
459+
return self.to_linexpr() + other
460+
except TypeError:
461+
return NotImplemented
448462

449463
def __radd__(self, other: int) -> Variable | NotImplementedType:
450464
# This is needed for using python's sum function
@@ -456,7 +470,10 @@ def __sub__(
456470
"""
457471
Subtract linear expressions or other variables from the variables.
458472
"""
459-
return self.to_linexpr() - other
473+
try:
474+
return self.to_linexpr() - other
475+
except TypeError:
476+
return NotImplemented
460477

461478
def __le__(self, other: SideLike) -> Constraint:
462479
return self.to_linexpr().__le__(other)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
import xarray as xr
7+
8+
from linopy import LESS_EQUAL, Model, Variable
9+
from linopy.testing import assert_linequal
10+
11+
12+
class SomeOtherDatatype:
13+
"""
14+
A class that is not a subclass of xarray.DataArray, but stores data in a compatible way.
15+
It defines all necessary arithmetrics AND __array_ufunc__ to ensure that operations are
16+
performed on the active_data.
17+
"""
18+
19+
def __init__(self, data: xr.DataArray) -> None:
20+
self.data1 = data
21+
self.data2 = data.copy()
22+
self.active = 1
23+
24+
def activate(self, active: int) -> None:
25+
self.active = active
26+
27+
@property
28+
def active_data(self) -> xr.DataArray:
29+
return self.data1 if self.active == 1 else self.data2
30+
31+
def __add__(self, other: Any) -> xr.DataArray:
32+
return self.active_data + other
33+
34+
def __sub__(self, other: Any) -> xr.DataArray:
35+
return self.active_data - other
36+
37+
def __mul__(self, other: Any) -> xr.DataArray:
38+
return self.active_data * other
39+
40+
def __truediv__(self, other: Any) -> xr.DataArray:
41+
return self.active_data / other
42+
43+
def __radd__(self, other: Any) -> Any:
44+
return other + self.active_data
45+
46+
def __rsub__(self, other: Any) -> Any:
47+
return other - self.active_data
48+
49+
def __rmul__(self, other: Any) -> Any:
50+
return other * self.active_data
51+
52+
def __rtruediv__(self, other: Any) -> Any:
53+
return other / self.active_data
54+
55+
def __neg__(self) -> xr.DataArray:
56+
return -self.active_data
57+
58+
def __pos__(self) -> xr.DataArray:
59+
return +self.active_data
60+
61+
def __abs__(self) -> xr.DataArray:
62+
return abs(self.active_data)
63+
64+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # type: ignore
65+
# Ensure we always use the active_data when interacting with numpy/xarray operations
66+
new_inputs = [
67+
inp.active_data if isinstance(inp, SomeOtherDatatype) else inp
68+
for inp in inputs
69+
]
70+
return getattr(ufunc, method)(*new_inputs, **kwargs)
71+
72+
73+
@pytest.fixture(
74+
params=[
75+
(pd.RangeIndex(10, name="first"),),
76+
(
77+
pd.Index(range(5), name="first"),
78+
pd.Index(range(3), name="second"),
79+
pd.Index(range(2), name="third"),
80+
),
81+
],
82+
ids=["single_dim", "multi_dim"],
83+
)
84+
def m(request) -> Model: # type: ignore
85+
m = Model()
86+
x = m.add_variables(coords=request.param, name="x")
87+
m.add_variables(0, 10, name="z")
88+
m.add_constraints(x, LESS_EQUAL, 0, name="c")
89+
return m
90+
91+
92+
def test_arithmetric_operations_variable(m: Model) -> None:
93+
x: Variable = m.variables["x"]
94+
rng = np.random.default_rng()
95+
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
96+
other_datatype = SomeOtherDatatype(data.copy())
97+
assert_linequal(x + data, x + other_datatype) # type: ignore
98+
assert_linequal(x - data, x - other_datatype) # type: ignore
99+
assert_linequal(x * data, x * other_datatype) # type: ignore
100+
assert_linequal(x / data, x / other_datatype) # type: ignore
101+
assert_linequal(data * x, other_datatype * x) # type: ignore
102+
assert x.__add__(object()) is NotImplemented # type: ignore
103+
assert x.__sub__(object()) is NotImplemented # type: ignore
104+
assert x.__mul__(object()) is NotImplemented # type: ignore
105+
assert x.__truediv__(object()) is NotImplemented # type: ignore
106+
assert x.__pow__(object()) is NotImplemented # type: ignore
107+
assert x.__pow__(3) is NotImplemented
108+
109+
110+
def test_arithmetric_operations_expr(m: Model) -> None:
111+
x = m.variables["x"]
112+
expr = x + 3
113+
rng = np.random.default_rng()
114+
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
115+
other_datatype = SomeOtherDatatype(data.copy())
116+
assert_linequal(expr + data, expr + other_datatype)
117+
assert_linequal(expr - data, expr - other_datatype)
118+
assert_linequal(expr * data, expr * other_datatype)
119+
assert_linequal(expr / data, expr / other_datatype)
120+
assert expr.__add__(object()) is NotImplemented
121+
assert expr.__sub__(object()) is NotImplemented
122+
assert expr.__mul__(object()) is NotImplemented
123+
assert expr.__truediv__(object()) is NotImplemented
124+
125+
126+
def test_arithmetric_operations_con(m: Model) -> None:
127+
c = m.constraints["c"]
128+
x = m.variables["x"]
129+
rng = np.random.default_rng()
130+
data = xr.DataArray(rng.random(x.shape), coords=x.coords)
131+
other_datatype = SomeOtherDatatype(data.copy())
132+
assert_linequal(c.lhs + data, c.lhs + other_datatype)
133+
assert_linequal(c.lhs - data, c.lhs - other_datatype)
134+
assert_linequal(c.lhs * data, c.lhs * other_datatype)
135+
assert_linequal(c.lhs / data, c.lhs / other_datatype)
136+
assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore
137+
assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore
138+
assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore
139+
assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore

0 commit comments

Comments
 (0)