Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Release Notes

.. Upcoming Version
.. ----------------
.. * Improved constraint equality check in `linopy.testing.assert_conequal` to less strict optionally

Version 0.5.6
--------------
Expand Down
22 changes: 19 additions & 3 deletions linopy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,25 @@ def assert_quadequal(
return assert_equal(_expr_unwrap(a), _expr_unwrap(b))


def assert_conequal(a: Constraint, b: Constraint) -> None:
"""Assert that two constraints are equal."""
return assert_equal(_con_unwrap(a), _con_unwrap(b))
def assert_conequal(a: Constraint, b: Constraint, strict: bool = True) -> None:
"""
Assert that two constraints are equal.

Parameters
----------
a: Constraint
The first constraint.
b: Constraint
The second constraint.
strict: bool
Whether to compare the constraints strictly. If not, only compare mathematically relevant parts.
"""
if strict:
assert_equal(_con_unwrap(a), _con_unwrap(b))
else:
assert_linequal(a.lhs, b.lhs)
assert_equal(a.sign, b.sign)
assert_equal(a.rhs, b.rhs)


def assert_model_equal(a: Model, b: Model) -> None:
Expand Down
25 changes: 25 additions & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@ def test_constraint_assignment() -> None:
assert_conequal(m.constraints.con0, con0)


def test_constraint_equality() -> None:
m: Model = Model()

lower: xr.DataArray = xr.DataArray(
np.zeros((10, 10)), coords=[range(10), range(10)]
)
upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)])
x = m.add_variables(lower, upper, name="x")
y = m.add_variables(name="y")

con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0)

assert_conequal(con0, 1 * x + 10 * y == 0, strict=False)
assert_conequal(1 * x + 10 * y == 0, 1 * x + 10 * y == 0, strict=False)

with pytest.raises(AssertionError):
assert_conequal(con0, 1 * x + 10 * y <= 0, strict=False)

with pytest.raises(AssertionError):
assert_conequal(con0, 1 * x + 10 * y >= 0, strict=False)

with pytest.raises(AssertionError):
assert_conequal(10 * y + 2 * x == 0, 1 * x + 10 * y == 0, strict=False)


def test_constraints_getattr_formatted() -> None:
m: Model = Model()
x = m.add_variables(0, 10, name="x")
Expand Down
Loading