Skip to content

Commit 33432ac

Browse files
Merge branch 'master' into oetc-support
2 parents 302246e + 8266b8b commit 33432ac

File tree

6 files changed

+64
-29
lines changed

6 files changed

+64
-29
lines changed

doc/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Release Notes
33

44
.. Upcoming Version
55
.. ----------------
6+
.. * Improved constraint equality check in `linopy.testing.assert_conequal` to less strict optionally
67
78
Version 0.5.6
89
--------------

linopy/expressions.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,49 +1941,40 @@ def merge(
19411941
dim : str
19421942
Dimension along which the expressions should be concatenated.
19431943
cls : type
1944-
Type of the resulting expression.
1944+
Explicitly set the type of the resulting expression (So that the type checker will know the return type)
19451945
**kwargs
19461946
Additional keyword arguments passed to xarray.concat. Defaults to
19471947
{coords: "minimal", compat: "override"} or, in the special case described
19481948
above, to {coords: "minimal", compat: "override", "join": "override"}.
19491949
19501950
Returns
19511951
-------
1952-
res : linopy.LinearExpression
1952+
res : linopy.LinearExpression or linopy.QuadraticExpression
19531953
"""
1954-
if cls is None:
1955-
warn(
1956-
"Using merge without specifying the class is deprecated",
1957-
DeprecationWarning,
1958-
)
1959-
cls = LinearExpression
1960-
1961-
linopy_types = (variables.Variable, LinearExpression, QuadraticExpression)
1962-
19631954
if not isinstance(exprs, list) and len(add_exprs):
19641955
warn(
19651956
"Passing a tuple to the merge function is deprecated. Please pass a list of objects to be merged",
19661957
DeprecationWarning,
19671958
)
19681959
exprs = [exprs] + list(add_exprs) # type: ignore
1969-
model = exprs[0].model
19701960

1971-
if (
1972-
cls is QuadraticExpression
1973-
and dim == TERM_DIM
1974-
and any(type(e) is LinearExpression for e in exprs)
1975-
):
1961+
has_quad_expression = any(type(e) is QuadraticExpression for e in exprs)
1962+
has_linear_expression = any(type(e) is LinearExpression for e in exprs)
1963+
if cls is None:
1964+
cls = QuadraticExpression if has_quad_expression else LinearExpression
1965+
1966+
if cls is QuadraticExpression and dim == TERM_DIM and has_linear_expression:
19761967
raise ValueError(
19771968
"Cannot merge linear and quadratic expressions along term dimension."
19781969
"Convert to QuadraticExpression first."
19791970
)
19801971

1981-
if cls is not QuadraticExpression and any(
1982-
type(e) is QuadraticExpression for e in exprs
1983-
):
1984-
raise ValueError(
1985-
"Cannot merge linear and quadratic expressions to QuadraticExpression"
1986-
)
1972+
if has_quad_expression and cls is not QuadraticExpression:
1973+
raise ValueError("Cannot merge linear expressions to QuadraticExpression")
1974+
1975+
linopy_types = (variables.Variable, LinearExpression, QuadraticExpression)
1976+
1977+
model = exprs[0].model
19871978

19881979
if cls in linopy_types and dim in HELPER_DIMS:
19891980
coord_dims = [

linopy/testing.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,25 @@ def assert_quadequal(
2929
return assert_equal(_expr_unwrap(a), _expr_unwrap(b))
3030

3131

32-
def assert_conequal(a: Constraint, b: Constraint) -> None:
33-
"""Assert that two constraints are equal."""
34-
return assert_equal(_con_unwrap(a), _con_unwrap(b))
32+
def assert_conequal(a: Constraint, b: Constraint, strict: bool = True) -> None:
33+
"""
34+
Assert that two constraints are equal.
35+
36+
Parameters
37+
----------
38+
a: Constraint
39+
The first constraint.
40+
b: Constraint
41+
The second constraint.
42+
strict: bool
43+
Whether to compare the constraints strictly. If not, only compare mathematically relevant parts.
44+
"""
45+
if strict:
46+
assert_equal(_con_unwrap(a), _con_unwrap(b))
47+
else:
48+
assert_linequal(a.lhs, b.lhs)
49+
assert_equal(a.sign, b.sign)
50+
assert_equal(a.rhs, b.rhs)
3551

3652

3753
def assert_model_equal(a: Model, b: Model) -> None:

test/test_constraints.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,31 @@ def test_constraint_assignment() -> None:
4242
assert_conequal(m.constraints.con0, con0)
4343

4444

45+
def test_constraint_equality() -> None:
46+
m: Model = Model()
47+
48+
lower: xr.DataArray = xr.DataArray(
49+
np.zeros((10, 10)), coords=[range(10), range(10)]
50+
)
51+
upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
52+
x = m.add_variables(lower, upper, name="x")
53+
y = m.add_variables(name="y")
54+
55+
con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0)
56+
57+
assert_conequal(con0, 1 * x + 10 * y == 0, strict=False)
58+
assert_conequal(1 * x + 10 * y == 0, 1 * x + 10 * y == 0, strict=False)
59+
60+
with pytest.raises(AssertionError):
61+
assert_conequal(con0, 1 * x + 10 * y <= 0, strict=False)
62+
63+
with pytest.raises(AssertionError):
64+
assert_conequal(con0, 1 * x + 10 * y >= 0, strict=False)
65+
66+
with pytest.raises(AssertionError):
67+
assert_conequal(10 * y + 2 * x == 0, 1 * x + 10 * y == 0, strict=False)
68+
69+
4570
def test_constraints_getattr_formatted() -> None:
4671
m: Model = Model()
4772
x = m.add_variables(0, 10, name="x")

test/test_linear_expression.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,9 +1128,8 @@ def test_merge(x: Variable, y: Variable, z: Variable) -> None:
11281128
res = merge([expr1, expr2], cls=LinearExpression)
11291129
assert res.nterm == 6
11301130

1131-
with pytest.warns(DeprecationWarning):
1132-
res: LinearExpression = merge([expr1, expr2]) # type: ignore
1133-
assert res.nterm == 6
1131+
res: LinearExpression = merge([expr1, expr2]) # type: ignore
1132+
assert isinstance(res, LinearExpression)
11341133

11351134
# now concat with same length of terms
11361135
expr1 = z.sel(dim_0=0).sum("dim_1")

test/test_quadratic_expression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def test_merge_linear_expression_and_quadratic_expression(
238238
with pytest.raises(ValueError):
239239
merge([linexpr, quadexpr], cls=QuadraticExpression)
240240

241+
new_quad_ex = merge([linexpr.to_quadexpr(), quadexpr]) # type: ignore
242+
assert isinstance(new_quad_ex, QuadraticExpression)
243+
241244
with pytest.warns(DeprecationWarning):
242245
merge(quadexpr, quadexpr, cls=QuadraticExpression) # type: ignore
243246

0 commit comments

Comments
 (0)