Skip to content

Commit be4657d

Browse files
RobbieKiwiRobbie MuirFabianHofmann
authored
add linear expression from constant (#518)
* add linear expression from constant * add doc * add to_polars improvements * minor changes * minor changes * formatting * improve test coverage * fix tests * fixed test * fix: constants helpers and solver detection --------- Co-authored-by: Robbie Muir <[email protected]> Co-authored-by: Fabian <[email protected]>
1 parent 3807adb commit be4657d

13 files changed

+227
-18
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ benchmark/scripts/leftovers/
4949

5050
# direnv
5151
.envrc
52+
AGENTS.md

doc/release_notes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ Release Notes
22
=============
33

44
.. Upcoming Version
5-
5+
* Add convenience function to create LinearExpression from constant
66
* Fix compatibility for xpress versions below 9.6 (regression)
77
* Performance: Up to 50x faster ``repr()`` for variables/constraints via O(log n) label lookup and direct numpy indexing
88
* Performance: Up to 46x faster ``ncons`` property by replacing ``.flat.labels.unique()`` with direct counting

linopy/common.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
SIGNS_pretty,
3434
sign_replace_dict,
3535
)
36-
from linopy.types import CoordsLike, DimsLike
36+
from linopy.types import (
37+
CoordsLike,
38+
DimsLike,
39+
SideLike,
40+
)
3741

3842
if TYPE_CHECKING:
3943
from linopy.constraints import Constraint
@@ -1120,7 +1124,7 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
11201124
return wrapper
11211125

11221126

1123-
def is_constant(func: Callable[..., Any]) -> Callable[..., Any]:
1127+
def require_constant(func: Callable[..., Any]) -> Callable[..., Any]:
11241128
from linopy import expressions, variables
11251129

11261130
@wraps(func)
@@ -1129,7 +1133,8 @@ def wrapper(self: Any, arg: Any) -> Any:
11291133
arg,
11301134
variables.Variable
11311135
| variables.ScalarVariable
1132-
| expressions.LinearExpression,
1136+
| expressions.LinearExpression
1137+
| expressions.QuadraticExpression,
11331138
):
11341139
raise TypeError(f"Assigned rhs must be a constant, got {type(arg)}).")
11351140
return func(self, arg)
@@ -1325,3 +1330,40 @@ def __call__(self) -> bool:
13251330
stacklevel=2,
13261331
)
13271332
return self.value
1333+
1334+
1335+
def is_constant(x: SideLike) -> bool:
1336+
"""
1337+
Check if the given object is a constant type or an expression type without
1338+
any variables.
1339+
1340+
Note that an expression such as ``x - x + 1`` will evaluate to ``False`` as
1341+
the expression is not simplified before evaluation.
1342+
1343+
Parameters
1344+
----------
1345+
x : SideLike
1346+
The object to check.
1347+
1348+
Returns
1349+
-------
1350+
bool
1351+
True if the object is constant-like, False otherwise.
1352+
"""
1353+
from linopy.expressions import (
1354+
SUPPORTED_CONSTANT_TYPES,
1355+
LinearExpression,
1356+
QuadraticExpression,
1357+
)
1358+
from linopy.variables import ScalarVariable, Variable
1359+
1360+
if isinstance(x, Variable | ScalarVariable):
1361+
return False
1362+
if isinstance(x, LinearExpression | QuadraticExpression):
1363+
return x.is_constant
1364+
if isinstance(x, SUPPORTED_CONSTANT_TYPES):
1365+
return True
1366+
raise TypeError(
1367+
"Expected a constant, variable, or expression on the constraint side, "
1368+
f"got {type(x)}."
1369+
)

linopy/constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
group_terms_polars,
4444
has_optimized_model,
4545
infer_schema_polars,
46-
is_constant,
4746
iterate_slices,
4847
maybe_replace_signs,
4948
print_coord,
5049
print_single_constraint,
5150
print_single_expression,
5251
replace_by_map,
52+
require_constant,
5353
save_join,
5454
to_dataframe,
5555
to_polars,
@@ -457,7 +457,7 @@ def sign(self) -> DataArray:
457457
return self.data.sign
458458

459459
@sign.setter
460-
@is_constant
460+
@require_constant
461461
def sign(self, value: SignLike) -> None:
462462
value = maybe_replace_signs(DataArray(value)).broadcast_like(self.sign)
463463
self._data = assign_multiindex_safe(self.data, sign=value)

linopy/expressions.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
get_index_map,
5959
group_terms_polars,
6060
has_optimized_model,
61+
is_constant,
6162
iterate_slices,
6263
print_coord,
6364
print_single_expression,
@@ -441,6 +442,11 @@ def __repr__(self) -> str:
441442

442443
return "\n".join(lines)
443444

445+
@property
446+
def is_constant(self) -> bool:
447+
"""True if the expression contains no variables."""
448+
return self.data.sizes[TERM_DIM] == 0
449+
444450
def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None:
445451
"""
446452
Print the linear expression.
@@ -840,9 +846,7 @@ def cumsum(
840846
dim_dict = {dim_name: self.data.sizes[dim_name] for dim_name in dim}
841847
return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna)
842848

843-
def to_constraint(
844-
self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike
845-
) -> Constraint:
849+
def to_constraint(self, sign: SignLike, rhs: SideLike) -> Constraint:
846850
"""
847851
Convert a linear expression to a constraint.
848852
@@ -859,6 +863,11 @@ def to_constraint(
859863
which are moved to the left-hand-side and constant values which are moved
860864
to the right-hand side.
861865
"""
866+
if self.is_constant and is_constant(rhs):
867+
raise ValueError(
868+
f"Both sides of the constraint are constant. At least one side must contain variables. {self} {rhs}"
869+
)
870+
862871
all_to_lhs = (self - rhs).data
863872
data = assign_multiindex_safe(
864873
all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const
@@ -1439,12 +1448,18 @@ def to_polars(self) -> pl.DataFrame:
14391448
14401449
The resulting DataFrame represents a long table format of the all
14411450
non-masked expressions with non-zero coefficients. It contains the
1442-
columns `coeffs`, `vars`.
1451+
columns `coeffs`, `vars`, `const`. The coeffs and vars columns will be null if the expression is constant.
14431452
14441453
Returns
14451454
-------
14461455
df : polars.DataFrame
14471456
"""
1457+
if self.is_constant:
1458+
df = pl.DataFrame(
1459+
{"const": self.data["const"].values.reshape(-1)}
1460+
).with_columns(pl.lit(None).alias("coeffs"), pl.lit(None).alias("vars"))
1461+
return df.select(["vars", "coeffs", "const"])
1462+
14481463
df = to_polars(self.data)
14491464
df = filter_nulls_polars(df)
14501465
df = group_terms_polars(df)
@@ -1647,6 +1662,26 @@ def process_one(
16471662

16481663
return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0]
16491664

1665+
@classmethod
1666+
def from_constant(cls, model: Model, constant: ConstantLike) -> LinearExpression:
1667+
"""
1668+
Create a linear expression from a constant value or series
1669+
1670+
Parameters
1671+
----------
1672+
model : linopy.Model
1673+
The model to which the constant expression will belong.
1674+
constant : int/float/array_like
1675+
The constant value for the linear expression.
1676+
1677+
Returns
1678+
-------
1679+
linopy.LinearExpression
1680+
A linear expression representing the constant value.
1681+
"""
1682+
const_da = as_dataarray(constant)
1683+
return LinearExpression(const_da, model)
1684+
16501685

16511686
class QuadraticExpression(BaseExpression):
16521687
"""
@@ -1835,12 +1870,22 @@ def to_polars(self, **kwargs: Any) -> pl.DataFrame:
18351870
18361871
The resulting DataFrame represents a long table format of the all
18371872
non-masked expressions with non-zero coefficients. It contains the
1838-
columns `coeffs`, `vars`.
1873+
columns `vars1`, `vars2`, `coeffs`, `const`. If the expression is constant, the `vars1` and `vars2` and `coeffs` columns will be null.
18391874
18401875
Returns
18411876
-------
18421877
df : polars.DataFrame
18431878
"""
1879+
if self.is_constant:
1880+
df = pl.DataFrame(
1881+
{"const": self.data["const"].values.reshape(-1)}
1882+
).with_columns(
1883+
pl.lit(None).alias("coeffs"),
1884+
pl.lit(None).alias("vars1"),
1885+
pl.lit(None).alias("vars2"),
1886+
)
1887+
return df.select(["vars1", "vars2", "coeffs", "const"])
1888+
18441889
vars = self.data.vars.assign_coords(
18451890
{FACTOR_DIM: ["vars1", "vars2"]}
18461891
).to_dataset(FACTOR_DIM)

linopy/solvers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@
6060
with contextlib.suppress(ModuleNotFoundError):
6161
import gurobipy
6262

63-
available_solvers.append("gurobi")
63+
try:
64+
with contextlib.closing(gurobipy.Env()):
65+
available_solvers.append("gurobi")
66+
except gurobipy.GurobiError:
67+
pass
6468
with contextlib.suppress(ModuleNotFoundError):
6569
_new_highspy_mps_layout = None
6670
import highspy

linopy/variables.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@
4242
get_dims_with_index_levels,
4343
get_label_position,
4444
has_optimized_model,
45-
is_constant,
4645
iterate_slices,
4746
print_coord,
4847
print_single_variable,
48+
require_constant,
4949
save_join,
5050
set_int_index,
5151
to_dataframe,
@@ -764,7 +764,7 @@ def upper(self) -> DataArray:
764764
return self.data.upper
765765

766766
@upper.setter
767-
@is_constant
767+
@require_constant
768768
def upper(self, value: ConstantLike) -> None:
769769
"""
770770
Set the upper bounds of the variables.
@@ -788,7 +788,7 @@ def lower(self) -> DataArray:
788788
return self.data.lower
789789

790790
@lower.setter
791-
@is_constant
791+
@require_constant
792792
def lower(self, value: ConstantLike) -> None:
793793
"""
794794
Set the lower bounds of the variables.

test/test_common.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
from xarray import DataArray
1515
from xarray.testing.assertions import assert_equal
1616

17-
from linopy import LinearExpression, Variable
17+
from linopy import LinearExpression, Model, Variable
1818
from linopy.common import (
1919
align,
2020
as_dataarray,
2121
assign_multiindex_safe,
2222
best_int,
2323
get_dims_with_index_levels,
24+
is_constant,
2425
iterate_slices,
2526
)
2627
from linopy.testing import assert_linequal, assert_varequal
@@ -711,3 +712,28 @@ def test_align(x: Variable, u: Variable) -> None: # noqa: F811
711712
assert expr_obs.shape == (1, 1) # _term dim
712713
assert isinstance(expr_obs, LinearExpression)
713714
assert_linequal(expr_obs, expr.loc[[1]])
715+
716+
717+
def test_is_constant() -> None:
718+
model = Model()
719+
index = pd.Index(range(10), name="t")
720+
a = model.add_variables(name="a", coords=[index])
721+
b = a.sel(t=1)
722+
c = a * 2
723+
d = a * a
724+
725+
non_constant = [a, b, c, d]
726+
for nc in non_constant:
727+
assert not is_constant(nc)
728+
729+
constant_values = [
730+
5,
731+
3.14,
732+
np.int32(7),
733+
np.float64(2.71),
734+
pd.Series([1, 2, 3]),
735+
np.array([4, 5, 6]),
736+
xr.DataArray([k for k in range(10)], coords=[index]),
737+
]
738+
for cv in constant_values:
739+
assert is_constant(cv)

test/test_constraint.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import polars as pl
1313
import pytest
1414
import xarray as xr
15-
import xarray.core
1615
from xarray.testing import assert_equal
1716

1817
import linopy
@@ -70,6 +69,12 @@ def test_empty_constraints_repr() -> None:
7069
Model().constraints.__repr__()
7170

7271

72+
def test_cannot_create_constraint_without_variable() -> None:
73+
model = linopy.Model()
74+
with pytest.raises(ValueError):
75+
_ = linopy.LinearExpression(12, model) == linopy.LinearExpression(13, model)
76+
77+
7378
def test_constraints_getter(m: Model, c: linopy.constraints.Constraint) -> None:
7479
assert c.shape == (10,)
7580
assert isinstance(m.constraints[["c"]], Constraints)

test/test_linear_expression.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,43 @@ def test_linear_expression_from_tuples_bad_calls(
11231123
LinearExpression.from_tuples(10)
11241124

11251125

1126+
def test_linear_expression_from_constant_scalar(m: Model) -> None:
1127+
expr = LinearExpression.from_constant(model=m, constant=10)
1128+
assert expr.is_constant
1129+
assert isinstance(expr, LinearExpression)
1130+
assert (expr.const == 10).all()
1131+
1132+
1133+
def test_linear_expression_from_constant_1D(m: Model) -> None:
1134+
arr = pd.Series(index=pd.Index([0, 1], name="t"), data=[10, 20])
1135+
expr = LinearExpression.from_constant(model=m, constant=arr)
1136+
assert isinstance(expr, LinearExpression)
1137+
assert list(expr.coords.keys())[0] == "t"
1138+
assert expr.nterm == 0
1139+
assert (expr.const.values == [10, 20]).all()
1140+
assert expr.is_constant
1141+
1142+
1143+
def test_constant_linear_expression_to_polars_2D(m: Model) -> None:
1144+
index_a = pd.Index([0, 1], name="a")
1145+
index_b = pd.Index([0, 1, 2], name="b")
1146+
arr = np.array([[10, 20, 30], [40, 50, 60]])
1147+
const = xr.DataArray(data=arr, coords=[index_a, index_b])
1148+
1149+
le_variable = m.add_variables(name="var", coords=[index_a, index_b]) * 1 + const
1150+
assert not le_variable.is_constant
1151+
le_const = LinearExpression.from_constant(model=m, constant=const)
1152+
assert le_const.is_constant
1153+
1154+
var_pol = le_variable.to_polars()
1155+
const_pol = le_const.to_polars()
1156+
assert var_pol.shape == const_pol.shape
1157+
assert var_pol.columns == const_pol.columns
1158+
assert all(const_pol["const"] == var_pol["const"])
1159+
assert all(const_pol["coeffs"].is_null())
1160+
assert all(const_pol["vars"].is_null())
1161+
1162+
11261163
def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None:
11271164
expr = 10 * x + y + z
11281165
assert isinstance(expr.sanitize(), LinearExpression)

0 commit comments

Comments
 (0)