Skip to content

Commit e687249

Browse files
author
Robbie Muir
committed
added tests to improve code coverage
1 parent ce39e45 commit e687249

File tree

5 files changed

+36
-10
lines changed

5 files changed

+36
-10
lines changed

doc/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Release Notes
44
Future Version
55
---------------
66
**Minor Improvements**
7+
78
* Improved variable/expression arithmetic methods so that they correctly handle types
89

910
Upcoming Version

linopy/expressions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ def __add__(
13171317

13181318
def __radd__(self, other: ConstantLike) -> LinearExpression:
13191319
try:
1320-
return self.__add__(other)
1320+
return self + other
13211321
except TypeError:
13221322
return NotImplemented
13231323

@@ -1345,7 +1345,7 @@ def __sub__(
13451345

13461346
def __rsub__(self, other: ConstantLike | Variable) -> LinearExpression:
13471347
try:
1348-
return self.__neg__().__add__(other)
1348+
return (self * -1) + other
13491349
except TypeError:
13501350
return NotImplemented
13511351

@@ -1389,7 +1389,7 @@ def __rmul__(self, other: ConstantLike) -> LinearExpression:
13891389
Right-multiply the expr by a factor.
13901390
"""
13911391
try:
1392-
return self.__mul__(other)
1392+
return self * other
13931393
except TypeError:
13941394
return NotImplemented
13951395

@@ -1725,7 +1725,7 @@ def __mul__(self, other: SideLike) -> QuadraticExpression:
17251725
return NotImplemented
17261726

17271727
def __rmul__(self, other: SideLike) -> QuadraticExpression:
1728-
return self.__mul__(other)
1728+
return self * other
17291729

17301730
def __add__(self, other: SideLike) -> QuadraticExpression:
17311731
"""
@@ -1776,7 +1776,7 @@ def __rsub__(self, other: SideLike) -> QuadraticExpression:
17761776
Subtract expression from others.
17771777
"""
17781778
try:
1779-
return self.__neg__().__add__(other)
1779+
return (self * -1) + other
17801780
except TypeError:
17811781
return NotImplemented
17821782

linopy/variables.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,7 @@ class ScalarVariable:
15161516
"""
15171517
A scalar variable container.
15181518
1519-
In contrast to the Variable class, a ScalarVariable only contains
1520-
only one label. Use this class to create a expression or constraint
1519+
In contrast to the Variable class, a ScalarVariable only contains one label. Use this class to create a expression or constraint
15211520
in a rule.
15221521
"""
15231522

test/test_linear_expression.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from linopy.constants import HELPER_DIMS, TERM_DIM
1919
from linopy.expressions import ScalarLinearExpression
2020
from linopy.testing import assert_linequal, assert_quadequal
21+
from linopy.variables import ScalarVariable
2122

2223

2324
@pytest.fixture
@@ -227,6 +228,8 @@ def test_linear_expression_with_multiplication(x: Variable) -> None:
227228
with pytest.raises(TypeError):
228229
quad * quad
229230

231+
expr = x * 1
232+
assert isinstance(expr, LinearExpression)
230233
assert expr.__mul__(object()) is NotImplemented
231234
assert expr.__rmul__(object()) is NotImplemented
232235

@@ -285,6 +288,7 @@ def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None:
285288
assert isinstance(expr_2, LinearExpression)
286289
expr_3: LinearExpression = (expr - 10.0) * -1
287290
assert_linequal(expr_2, expr_3)
291+
assert expr.__rsub__(object()) is NotImplemented
288292

289293

290294
def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None:
@@ -494,14 +498,15 @@ def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None:
494498

495499

496500
def test_linear_expression_power(x: Variable) -> None:
497-
qd_expr = x**2
501+
expr: LinearExpression = x * 1.0
502+
qd_expr = expr**2
498503
assert isinstance(qd_expr, QuadraticExpression)
499504

500-
qd_expr2 = x.pow(2)
505+
qd_expr2 = expr.pow(2)
501506
assert_quadequal(qd_expr, qd_expr2)
502507

503508
with pytest.raises(ValueError):
504-
x**3
509+
expr**3
505510

506511

507512
def test_linear_expression_multiplication(
@@ -1085,13 +1090,31 @@ def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None:
10851090

10861091
expr4 = LinearExpression.from_tuples((10, x), (1, y), 1)
10871092
assert isinstance(expr4, LinearExpression)
1093+
assert (expr4.const == 1).all()
1094+
1095+
expr5 = LinearExpression.from_tuples(1, model=x.model)
1096+
assert isinstance(expr5, LinearExpression)
10881097

1098+
1099+
def test_linear_expression_from_tuples_bad_calls(
1100+
m: Model, x: Variable, y: Variable
1101+
) -> None:
10891102
with pytest.raises(ValueError):
10901103
LinearExpression.from_tuples((10, x), (1, y), x)
10911104

10921105
with pytest.raises(ValueError):
10931106
LinearExpression.from_tuples((10, x, 3), (1, y), 1)
10941107

1108+
sv = ScalarVariable(label=0, model=m)
1109+
with pytest.raises(TypeError):
1110+
LinearExpression.from_tuples((np.array([1, 1]), sv))
1111+
1112+
with pytest.raises(TypeError):
1113+
LinearExpression.from_tuples((x, x))
1114+
1115+
with pytest.raises(ValueError):
1116+
LinearExpression.from_tuples(10)
1117+
10951118

10961119
def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None:
10971120
expr = 10 * x + y + z

test/test_quadratic_expression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def test_matmul_expr_and_expr(x: Variable, y: Variable, z: Variable) -> None:
123123
assert expr.nterm == 6
124124
assert_quadequal(expr, target)
125125

126+
with pytest.raises(TypeError):
127+
(x**2) @ (y**2)
128+
126129

127130
def test_matmul_with_const(x: Variable) -> None:
128131
expr = x * x

0 commit comments

Comments
 (0)