Skip to content

Commit 8f78526

Browse files
committed
fix simplify tests: correct comment and add cancellation tests
- Fix misleading comment/error message (coefficient is 6, not 5) - Add test for full cancellation (x - x = 0) - Add test for partial cancellation (2x - 2x + 3y = 3y)
1 parent 967e540 commit 8f78526

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

test/test_linear_expression.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,10 +1200,10 @@ def test_simplify_basic(x: Variable) -> None:
12001200
assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}"
12011201

12021202
x_len = len(x.coords["dim_0"])
1203-
# Check that the coefficient is 5
1203+
# Check that the coefficient is 6 (2 + 3 + 1)
12041204
coeffs: np.ndarray = simplified.coeffs.values
12051205
assert len(coeffs) == x_len, f"Expected {x_len} coefficients, got {len(coeffs)}"
1206-
assert all(coeffs == 6.0), f"Expected coefficient 5.0, got {coeffs[0]}"
1206+
assert all(coeffs == 6.0), f"Expected coefficient 6.0, got {coeffs[0]}"
12071207

12081208

12091209
def test_simplify_multiple_dimensions() -> None:
@@ -1255,3 +1255,24 @@ def test_simplify_with_constant(x: Variable) -> None:
12551255
assert all(simplified.coeffs.values == 5.0), (
12561256
f"Expected coefficient 5.0, got {simplified.coeffs.values}"
12571257
)
1258+
1259+
1260+
def test_simplify_cancellation(x: Variable) -> None:
1261+
"""Test that terms cancel out correctly when coefficients sum to zero."""
1262+
expr = x - x
1263+
simplified = expr.simplify()
1264+
1265+
assert simplified.nterm == 0, f"Expected 0 terms, got {simplified.nterm}"
1266+
assert simplified.coeffs.values.size == 0
1267+
assert simplified.vars.values.size == 0
1268+
1269+
1270+
def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None:
1271+
"""Test partial cancellation where some terms cancel but others remain."""
1272+
expr = 2 * x - 2 * x + 3 * y
1273+
simplified = expr.simplify()
1274+
1275+
assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}"
1276+
assert all(simplified.coeffs.values == 3.0), (
1277+
f"Expected coefficient 3.0, got {simplified.coeffs.values}"
1278+
)

0 commit comments

Comments
 (0)