Skip to content

Commit 5d78e66

Browse files
Add unit tests
1 parent df4a4a5 commit 5d78e66

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

snake_mip_solver/solver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,12 @@ def get_solver_info(self) -> Dict[str, str]:
324324
"start_cell": str(self.puzzle.start_cell),
325325
"end_cell": str(self.puzzle.end_cell)
326326
}
327+
328+
def get_solve_stats(self) -> Dict[str, int]:
329+
"""
330+
Get statistics from the last solve attempt.
331+
332+
Returns:
333+
Dictionary with solving statistics including iterations, cutting planes added, etc.
334+
"""
335+
return self._solve_stats.copy()

tests/test_solver.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,149 @@ def test_edge_cases(self):
237237
)
238238
solver = SnakeSolver(puzzle_single_col)
239239
solution = solver.solve()
240-
assert solution == {(0,0), (1,0)}
240+
assert solution == {(0,0), (1,0)}
241+
242+
def test_connectivity_cutting_planes_disjoint_infeasible(self):
243+
"""Test that the cutting plane approach correctly identifies disjoint puzzles as infeasible."""
244+
# This puzzle forces disconnected components and should be infeasible
245+
puzzle = SnakePuzzle(
246+
row_sums=[4, 3, 3, 3, 0],
247+
col_sums=[3, 0, 4, 2, 4],
248+
start_cell=(0, 0),
249+
end_cell=(2, 0)
250+
)
251+
solver = SnakeSolver(puzzle)
252+
253+
# Should return None (infeasible)
254+
solution = solver.solve(verbose=False, max_iterations=5)
255+
assert solution is None
256+
257+
# Check that cutting planes were used
258+
stats = solver.get_solve_stats()
259+
assert stats['iterations'] >= 2 # Should take more than 1 iteration
260+
assert stats['disconnected_solutions_found'] >= 1
261+
assert stats['cutting_planes_added'] >= 1
262+
263+
def test_connectivity_cutting_planes_stats_initialization(self):
264+
"""Test that solve statistics are properly initialized."""
265+
puzzle = SnakePuzzle(
266+
row_sums=[2, 1, 2],
267+
col_sums=[1, 3, 1],
268+
start_cell=(0, 0),
269+
end_cell=(2, 2)
270+
)
271+
solver = SnakeSolver(puzzle)
272+
273+
# Stats should be initialized before solving
274+
initial_stats = solver.get_solve_stats()
275+
assert initial_stats['iterations'] == 0
276+
assert initial_stats['cutting_planes_added'] == 0
277+
assert initial_stats['disconnected_solutions_found'] == 0
278+
279+
# Solve and check stats are updated
280+
solution = solver.solve(verbose=False)
281+
assert solution is not None
282+
283+
final_stats = solver.get_solve_stats()
284+
assert final_stats['iterations'] >= 1
285+
286+
def test_connectivity_cutting_planes_stats_reset(self):
287+
"""Test that solve statistics are reset between solve calls."""
288+
puzzle = SnakePuzzle(
289+
row_sums=[4, 3, 3, 3, 0],
290+
col_sums=[3, 0, 4, 2, 4],
291+
start_cell=(0, 0),
292+
end_cell=(2, 0)
293+
)
294+
solver = SnakeSolver(puzzle)
295+
296+
# First solve attempt
297+
solution1 = solver.solve(verbose=False, max_iterations=3)
298+
stats1 = solver.get_solve_stats()
299+
300+
# Second solve attempt (note: cutting planes from first solve persist)
301+
solution2 = solver.solve(verbose=False, max_iterations=3)
302+
stats2 = solver.get_solve_stats()
303+
304+
# Both should be None (infeasible)
305+
assert solution1 is None
306+
assert solution2 is None
307+
308+
# Stats should be reset, not accumulated
309+
assert stats1['iterations'] == 2
310+
assert stats2['iterations'] == 1
311+
312+
# First solve should find disconnected solutions
313+
assert stats1['disconnected_solutions_found'] == 1
314+
# Second solve shouldn't find any disconnected solution because cutting planes from first solve persist
315+
assert stats2['disconnected_solutions_found'] == 0
316+
317+
def test_connectivity_valid_puzzle_no_cutting_planes(self):
318+
"""Test that valid puzzles don't trigger cutting planes."""
319+
puzzle = SnakePuzzle(
320+
row_sums=[1, 1, 1, 3, 2, 5],
321+
col_sums=[4, 3, 1, 1, 1, 3],
322+
start_cell=(0, 0),
323+
end_cell=(3, 5)
324+
)
325+
solver = SnakeSolver(puzzle)
326+
327+
solution = solver.solve(verbose=False)
328+
assert solution is not None
329+
330+
# Should solve in 1 iteration with no cutting planes
331+
stats = solver.get_solve_stats()
332+
assert stats['iterations'] == 1
333+
assert stats['cutting_planes_added'] == 0
334+
assert stats['disconnected_solutions_found'] == 0
335+
336+
def test_connectivity_max_iterations_parameter(self):
337+
"""Test that max_iterations parameter is respected."""
338+
puzzle = SnakePuzzle(
339+
row_sums=[4, 3, 3, 3, 0],
340+
col_sums=[3, 0, 4, 2, 4],
341+
start_cell=(0, 0),
342+
end_cell=(2, 0)
343+
)
344+
solver = SnakeSolver(puzzle)
345+
346+
# Test with limited iterations
347+
solution0 = solver.solve(verbose=False, max_iterations=0)
348+
assert solution0 is None
349+
stats0 = solver.get_solve_stats()
350+
assert stats0['iterations'] == 0
351+
352+
solution1 = solver.solve(max_iterations=1)
353+
assert solution1 is None
354+
stats1 = solver.get_solve_stats()
355+
assert stats1['iterations'] == 1
356+
357+
def test_connectivity_verbose_output(self):
358+
"""Test that verbose output works correctly with cutting planes."""
359+
puzzle = SnakePuzzle(
360+
row_sums=[4, 3, 3, 3, 0],
361+
col_sums=[3, 0, 4, 2, 4],
362+
start_cell=(0, 0),
363+
end_cell=(2, 0)
364+
)
365+
solver = SnakeSolver(puzzle)
366+
367+
# Capture stdout
368+
captured_output = io.StringIO()
369+
sys.stdout = captured_output
370+
371+
try:
372+
solution = solver.solve(verbose=True, max_iterations=3)
373+
assert solution is None
374+
375+
# Restore stdout and check output
376+
sys.stdout = sys.__stdout__
377+
output = captured_output.getvalue()
378+
379+
# Should contain expected verbose messages
380+
assert "Solving Snake puzzle..." in output
381+
assert "Found disconnected solution" in output or "No solution exists" in output
382+
383+
finally:
384+
# Ensure stdout is restored even if test fails
385+
sys.stdout = sys.__stdout__

0 commit comments

Comments
 (0)