diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 5b08b0fc..127c61ef 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,7 @@ Release Notes .. Upcoming Version .. ---------------- +.. * Improved constraint equality check in `linopy.testing.assert_conequal` to less strict optionally Version 0.5.6 -------------- diff --git a/linopy/testing.py b/linopy/testing.py index dfa46081..0392064e 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -29,9 +29,25 @@ def assert_quadequal( return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) -def assert_conequal(a: Constraint, b: Constraint) -> None: - """Assert that two constraints are equal.""" - return assert_equal(_con_unwrap(a), _con_unwrap(b)) +def assert_conequal(a: Constraint, b: Constraint, strict: bool = True) -> None: + """ + Assert that two constraints are equal. + + Parameters + ---------- + a: Constraint + The first constraint. + b: Constraint + The second constraint. + strict: bool + Whether to compare the constraints strictly. If not, only compare mathematically relevant parts. + """ + if strict: + assert_equal(_con_unwrap(a), _con_unwrap(b)) + else: + assert_linequal(a.lhs, b.lhs) + assert_equal(a.sign, b.sign) + assert_equal(a.rhs, b.rhs) def assert_model_equal(a: Model, b: Model) -> None: diff --git a/test/test_constraints.py b/test/test_constraints.py index eb10b13a..cca010e8 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -42,6 +42,31 @@ def test_constraint_assignment() -> None: assert_conequal(m.constraints.con0, con0) +def test_constraint_equality() -> None: + m: Model = Model() + + lower: xr.DataArray = xr.DataArray( + np.zeros((10, 10)), coords=[range(10), range(10)] + ) + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + assert_conequal(con0, 1 * x + 10 * y == 0, strict=False) + assert_conequal(1 * x + 10 * y == 0, 1 * x + 10 * y == 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y <= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y >= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(10 * y + 2 * x == 0, 1 * x + 10 * y == 0, strict=False) + + def test_constraints_getattr_formatted() -> None: m: Model = Model() x = m.add_variables(0, 10, name="x")