Skip to content

Commit 93bfdfa

Browse files
author
Robbie Muir
committed
Further typing changes
1 parent 4be198f commit 93bfdfa

File tree

5 files changed

+51
-20
lines changed

5 files changed

+51
-20
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ benchmark/notebooks/.ipynb_checkpoints
3434
benchmark/scripts/__pycache__
3535
benchmark/scripts/benchmarks-pypsa-eur/__pycache__
3636
benchmark/scripts/leftovers/
37+
38+
# IDE
39+
.idea/

linopy/expressions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,14 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None
484484
)
485485
print(self)
486486

487+
@overload
488+
def __add__(
489+
self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression
490+
) -> LinearExpression: ...
491+
492+
@overload
493+
def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ...
494+
487495
def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
488496
"""
489497
Add an expression to others.
@@ -506,24 +514,16 @@ def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
506514
def __radd__(self, other: ConstantLike) -> LinearExpression:
507515
return self.__add__(other)
508516

509-
def __sub__(self, other: SideLike) -> LinearExpression:
510-
"""
511-
Subtract others from expression.
512-
513-
Note: If other is a numpy array or pandas object without axes names,
514-
dimension names of self will be filled in other
515-
"""
516-
if isinstance(other, QuadraticExpression):
517-
return other.__rsub__(self)
517+
@overload
518+
def __sub__(
519+
self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression
520+
) -> LinearExpression: ...
518521

519-
try:
520-
if np.isscalar(other):
521-
return self.assign_multiindex_safe(const=self.const - other)
522+
@overload
523+
def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ...
522524

523-
other = as_expression(other, model=self.model, dims=self.coord_dims)
524-
return merge([self, -other], cls=self.__class__)
525-
except TypeError:
526-
return NotImplemented
525+
def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
526+
return self.__add__(-other)
527527

528528
def __neg__(self) -> LinearExpression:
529529
"""

linopy/variables.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,15 @@ def __truediv__(
455455
except TypeError:
456456
return NotImplemented
457457

458-
def __add__(self, other: SideLike) -> LinearExpression:
458+
@overload
459+
def __add__(
460+
self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression
461+
) -> LinearExpression: ...
462+
463+
@overload
464+
def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ...
465+
466+
def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
459467
"""
460468
Add variables to linear expressions or other variables.
461469
"""
@@ -470,7 +478,15 @@ def __radd__(self, other: ConstantLike) -> LinearExpression:
470478
except ValueError:
471479
return NotImplemented
472480

473-
def __sub__(self, other: SideLike) -> LinearExpression:
481+
@overload
482+
def __sub__(
483+
self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression
484+
) -> LinearExpression: ...
485+
486+
@overload
487+
def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ...
488+
489+
def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
474490
"""
475491
Subtract linear expressions or other variables from the variables.
476492
"""

test/test_linear_expression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,14 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) ->
226226
assert_linequal(expr, expr2)
227227

228228

229+
def test_linear_expression_with_raddition(m: Model, x: Variable):
230+
expr = x * 1.0
231+
expr_2: LinearExpression = 10.0 + expr # type: ignore
232+
assert isinstance(expr, LinearExpression)
233+
expr_3: LinearExpression = expr + 10.0 # type: ignore
234+
assert_linequal(expr_2, expr_3)
235+
236+
229237
def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) -> None:
230238
expr = x - y
231239
assert isinstance(expr, LinearExpression)

test/test_quadratic_expression.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,12 @@ def test_quadratic_expression_raddition(x: Variable, y: Variable) -> None:
144144
assert (expr.const == 5).all()
145145
assert expr.nterm == 2
146146

147-
with pytest.raises(TypeError):
148-
5 + x * y + x
147+
expr_2 = 5 + x * y + x
148+
assert isinstance(expr_2, QuadraticExpression)
149+
assert (expr_2.const == 5).all()
150+
assert expr_2.nterm == 2
151+
152+
assert_quadequal(expr, expr_2)
149153

150154

151155
def test_quadratic_expression_subtraction(x: Variable, y: Variable) -> None:

0 commit comments

Comments
 (0)