1212from collections .abc import Callable , Hashable , Iterator , Mapping , Sequence
1313from dataclasses import dataclass , field
1414from itertools import product , zip_longest
15- from typing import TYPE_CHECKING , Any
15+ from typing import TYPE_CHECKING , Any , TypeVar , overload
1616from warnings import warn
1717
1818import numpy as np
@@ -531,10 +531,18 @@ def __neg__(self) -> LinearExpression:
531531 """
532532 return self .assign_multiindex_safe (coeffs = - self .coeffs , const = - self .const )
533533
534+ @overload
534535 def __mul__ (
535- self ,
536+ self : GenericLinearExpression , other : ConstantLike
537+ ) -> GenericLinearExpression : ...
538+
539+ @overload
540+ def __mul__ (self , other : VariableLike | ExpressionLike ) -> QuadraticExpression : ...
541+
542+ def __mul__ (
543+ self : GenericLinearExpression ,
536544 other : SideLike ,
537- ) -> LinearExpression | QuadraticExpression :
545+ ) -> GenericLinearExpression | QuadraticExpression :
538546 """
539547 Multiply the expr by a factor.
540548 """
@@ -577,7 +585,9 @@ def _multiply_by_linear_expression(
577585 res = res + self .reset_const () * other .const
578586 return res # type: ignore
579587
580- def _multiply_by_constant (self , other : ConstantLike ) -> LinearExpression :
588+ def _multiply_by_constant (
589+ self : GenericLinearExpression , other : ConstantLike
590+ ) -> GenericLinearExpression :
581591 multiplier = as_dataarray (other , coords = self .coords , dims = self .coord_dims )
582592 coeffs = self .coeffs * multiplier
583593 assert all (coeffs .sizes [d ] == s for d , s in self .coeffs .sizes .items ())
@@ -592,7 +602,9 @@ def __pow__(self, other: int) -> QuadraticExpression:
592602 raise ValueError ("Power must be 2." )
593603 return self * self # type: ignore
594604
595- def __rmul__ (self , other : ConstantLike ) -> LinearExpression :
605+ def __rmul__ (
606+ self : GenericLinearExpression , other : ConstantLike
607+ ) -> GenericLinearExpression :
596608 """
597609 Right-multiply the expr by a factor.
598610 """
@@ -611,11 +623,18 @@ def __matmul__(
611623 return (self * other ).sum (dim = common_dims )
612624
613625 def __div__ (
614- self , other : Variable | ConstantLike
615- ) -> LinearExpression | QuadraticExpression :
626+ self : GenericLinearExpression , other : SideLike
627+ ) -> GenericLinearExpression :
616628 try :
617629 if isinstance (
618- other , (LinearExpression , variables .Variable , variables .ScalarVariable )
630+ other ,
631+ (
632+ variables .Variable ,
633+ variables .ScalarVariable ,
634+ LinearExpression ,
635+ ScalarLinearExpression ,
636+ QuadraticExpression ,
637+ ),
619638 ):
620639 raise TypeError (
621640 "unsupported operand type(s) for /: "
@@ -627,8 +646,8 @@ def __div__(
627646 return NotImplemented
628647
629648 def __truediv__ (
630- self , other : Variable | ConstantLike
631- ) -> LinearExpression | QuadraticExpression :
649+ self : GenericLinearExpression , other : SideLike
650+ ) -> GenericLinearExpression :
632651 return self .__div__ (other )
633652
634653 def __le__ (self , rhs : SideLike ) -> Constraint :
@@ -1514,6 +1533,9 @@ def to_polars(self) -> pl.DataFrame:
15141533 iterate_slices = iterate_slices
15151534
15161535
1536+ GenericLinearExpression = TypeVar ("GenericLinearExpression" , bound = LinearExpression )
1537+
1538+
15171539class QuadraticExpression (LinearExpression ):
15181540 """
15191541 A quadratic expression consisting of terms of coefficients and variables.
@@ -1544,7 +1566,7 @@ def __init__(self, data: Dataset | None, model: Model) -> None:
15441566 data = xr .Dataset (data .transpose (..., FACTOR_DIM , TERM_DIM ))
15451567 self ._data = data
15461568
1547- def __mul__ (self , other : ConstantLike ) -> QuadraticExpression :
1569+ def __mul__ (self , other : SideLike ) -> QuadraticExpression :
15481570 """
15491571 Multiply the expr by a factor.
15501572 """
@@ -1553,6 +1575,7 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression:
15531575 (
15541576 LinearExpression ,
15551577 QuadraticExpression ,
1578+ ScalarLinearExpression ,
15561579 variables .Variable ,
15571580 variables .ScalarVariable ,
15581581 ),
@@ -1562,9 +1585,9 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression:
15621585 f"{ type (self )} and { type (other )} . "
15631586 "Higher order non-linear expressions are not yet supported."
15641587 )
1565- return super ().__mul__ (other ) # type: ignore
1588+ return super ().__mul__ (other )
15661589
1567- def __rmul__ (self , other : ConstantLike ) -> QuadraticExpression :
1590+ def __rmul__ (self , other : SideLike ) -> QuadraticExpression :
15681591 return self .__mul__ (other )
15691592
15701593 @property
0 commit comments