Skip to content

Commit 4be198f

Browse files
author
Robbie Muir
committed
Changes to typing
1 parent 4cb1f37 commit 4be198f

File tree

5 files changed

+62
-25
lines changed

5 files changed

+62
-25
lines changed

linopy/expressions.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
1313
from dataclasses import dataclass, field
1414
from itertools import product, zip_longest
15-
from typing import TYPE_CHECKING, Any
15+
from typing import TYPE_CHECKING, Any, TypeVar, overload
1616
from warnings import warn
1717

1818
import numpy as np
@@ -531,10 +531,18 @@ def __neg__(self) -> LinearExpression:
531531
"""
532532
return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const)
533533

534+
@overload
534535
def __mul__(
535-
self,
536+
self: GenericLinearExpression, other: ConstantLike
537+
) -> GenericLinearExpression: ...
538+
539+
@overload
540+
def __mul__(self, other: VariableLike | ExpressionLike) -> QuadraticExpression: ...
541+
542+
def __mul__(
543+
self: GenericLinearExpression,
536544
other: SideLike,
537-
) -> LinearExpression | QuadraticExpression:
545+
) -> GenericLinearExpression | QuadraticExpression:
538546
"""
539547
Multiply the expr by a factor.
540548
"""
@@ -577,7 +585,9 @@ def _multiply_by_linear_expression(
577585
res = res + self.reset_const() * other.const
578586
return res # type: ignore
579587

580-
def _multiply_by_constant(self, other: ConstantLike) -> LinearExpression:
588+
def _multiply_by_constant(
589+
self: GenericLinearExpression, other: ConstantLike
590+
) -> GenericLinearExpression:
581591
multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
582592
coeffs = self.coeffs * multiplier
583593
assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items())
@@ -592,7 +602,9 @@ def __pow__(self, other: int) -> QuadraticExpression:
592602
raise ValueError("Power must be 2.")
593603
return self * self # type: ignore
594604

595-
def __rmul__(self, other: ConstantLike) -> LinearExpression:
605+
def __rmul__(
606+
self: GenericLinearExpression, other: ConstantLike
607+
) -> GenericLinearExpression:
596608
"""
597609
Right-multiply the expr by a factor.
598610
"""
@@ -611,11 +623,18 @@ def __matmul__(
611623
return (self * other).sum(dim=common_dims)
612624

613625
def __div__(
614-
self, other: Variable | ConstantLike
615-
) -> LinearExpression | QuadraticExpression:
626+
self: GenericLinearExpression, other: SideLike
627+
) -> GenericLinearExpression:
616628
try:
617629
if isinstance(
618-
other, (LinearExpression, variables.Variable, variables.ScalarVariable)
630+
other,
631+
(
632+
variables.Variable,
633+
variables.ScalarVariable,
634+
LinearExpression,
635+
ScalarLinearExpression,
636+
QuadraticExpression,
637+
),
619638
):
620639
raise TypeError(
621640
"unsupported operand type(s) for /: "
@@ -627,8 +646,8 @@ def __div__(
627646
return NotImplemented
628647

629648
def __truediv__(
630-
self, other: Variable | ConstantLike
631-
) -> LinearExpression | QuadraticExpression:
649+
self: GenericLinearExpression, other: SideLike
650+
) -> GenericLinearExpression:
632651
return self.__div__(other)
633652

634653
def __le__(self, rhs: SideLike) -> Constraint:
@@ -1514,6 +1533,9 @@ def to_polars(self) -> pl.DataFrame:
15141533
iterate_slices = iterate_slices
15151534

15161535

1536+
GenericLinearExpression = TypeVar("GenericLinearExpression", bound=LinearExpression)
1537+
1538+
15171539
class QuadraticExpression(LinearExpression):
15181540
"""
15191541
A quadratic expression consisting of terms of coefficients and variables.
@@ -1544,7 +1566,7 @@ def __init__(self, data: Dataset | None, model: Model) -> None:
15441566
data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM))
15451567
self._data = data
15461568

1547-
def __mul__(self, other: ConstantLike) -> QuadraticExpression:
1569+
def __mul__(self, other: SideLike) -> QuadraticExpression:
15481570
"""
15491571
Multiply the expr by a factor.
15501572
"""
@@ -1553,6 +1575,7 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression:
15531575
(
15541576
LinearExpression,
15551577
QuadraticExpression,
1578+
ScalarLinearExpression,
15561579
variables.Variable,
15571580
variables.ScalarVariable,
15581581
),
@@ -1562,9 +1585,9 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression:
15621585
f"{type(self)} and {type(other)}. "
15631586
"Higher order non-linear expressions are not yet supported."
15641587
)
1565-
return super().__mul__(other) # type: ignore
1588+
return super().__mul__(other)
15661589

1567-
def __rmul__(self, other: ConstantLike) -> QuadraticExpression:
1590+
def __rmul__(self, other: SideLike) -> QuadraticExpression:
15681591
return self.__mul__(other)
15691592

15701593
@property

linopy/variables.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,14 @@
5252
)
5353
from linopy.config import options
5454
from linopy.constants import HELPER_DIMS, TERM_DIM
55-
from linopy.types import ConstantLike, DimsLike, NotImplementedType, SideLike
55+
from linopy.types import (
56+
ConstantLike,
57+
DimsLike,
58+
ExpressionLike,
59+
NotImplementedType,
60+
SideLike,
61+
VariableLike,
62+
)
5663

5764
if TYPE_CHECKING:
5865
from linopy.constraints import AnonymousScalarConstraint, Constraint
@@ -382,7 +389,13 @@ def __neg__(self) -> LinearExpression:
382389
"""
383390
return self.to_linexpr(-1)
384391

385-
def __mul__(self, other: SideLike) -> LinearExpression | QuadraticExpression:
392+
@overload
393+
def __mul__(self, other: ConstantLike) -> LinearExpression: ...
394+
395+
@overload
396+
def __mul__(self, other: ExpressionLike | VariableLike) -> QuadraticExpression: ...
397+
398+
def __mul__(self, other: SideLike) -> ExpressionLike:
386399
"""
387400
Multiply variables with a coefficient, variable, or expression.
388401
"""
@@ -398,7 +411,7 @@ def __rmul__(self, other: ConstantLike) -> LinearExpression:
398411
"""
399412
Right-multiply variables by a constant
400413
"""
401-
return self.to_linexpr(other)
414+
return self * other
402415

403416
def __pow__(self, other: int) -> QuadraticExpression:
404417
"""
@@ -1539,6 +1552,8 @@ def __mul__(self, coeff: int | float) -> ScalarLinearExpression:
15391552
return self.to_scalar_linexpr(coeff)
15401553

15411554
def __rmul__(self, coeff: int | float) -> ScalarLinearExpression:
1555+
if isinstance(coeff, Variable):
1556+
return NotImplemented
15421557
return self.to_scalar_linexpr(coeff)
15431558

15441559
def __div__(self, coeff: int | float) -> ScalarLinearExpression:

test/test_compatible_arithmetrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_arithmetric_operations_variable(m: Model) -> None:
9797
assert_linequal(x + data, x + other_datatype)
9898
assert_linequal(x - data, x - other_datatype)
9999
assert_linequal(x * data, x * other_datatype)
100-
assert_linequal(x / data, x / other_datatype)
100+
assert_linequal(x / data, x / other_datatype) # type: ignore
101101
assert_linequal(data * x, other_datatype * x) # type: ignore
102102
assert x.__add__(object()) is NotImplemented
103103
assert x.__sub__(object()) is NotImplemented
@@ -131,7 +131,7 @@ def test_arithmetric_operations_vars_and_expr(m: Model) -> None:
131131
assert_linequal(x**2 + x, x + x**2)
132132
assert_linequal(x**2 * 2, x**2 * 2)
133133
with pytest.raises(TypeError):
134-
_ = x**2 * x # type: ignore
134+
_ = x**2 * x
135135

136136

137137
def test_arithmetric_operations_con(m: Model) -> None:

test/test_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def test_default_setting_expression_sol_accessor(
386386
qexpr = 4 * x**2
387387
assert_equal(qexpr.solution, 4 * x.solution**2)
388388

389-
qexpr = 4 * x * y
389+
qexpr = 4 * (x * y) # type: ignore
390390
assert_equal(qexpr.solution, 4 * x.solution * y.solution)
391391

392392

test/test_typing.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55

66

77
def test_operations_with_data_arrays_are_typed_correctly() -> None:
8+
# Get the path of this file
9+
file_path = __file__
10+
result = api.run([file_path])
11+
assert result[2] == 0, "Mypy returned issues: " + result[0]
12+
813
m = linopy.Model()
914

1015
a: xr.DataArray = xr.DataArray([1, 2, 3])
1116

1217
v: linopy.Variable = m.add_variables(lower=0.0, name="v")
1318
e: linopy.LinearExpression = v * 1.0
1419
q = v * v
15-
assert isinstance(q, linopy.QuadraticExpression)
1620

1721
_ = a * v
1822
_ = v * a
@@ -25,8 +29,3 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None:
2529
_ = a * q
2630
_ = q * a
2731
_ = q + a
28-
29-
# Get the path of this file
30-
file_path = __file__
31-
result = api.run([file_path])
32-
assert result[2] == 0, "Mypy returned issues: " + result[0]

0 commit comments

Comments
 (0)