Skip to content

Commit c357de4

Browse files
Add test_puzzle
1 parent 3c5b67a commit c357de4

File tree

2 files changed

+140
-21
lines changed

2 files changed

+140
-21
lines changed

snake_mip_solver/puzzle.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self,
5555
self.end_cell = end_cell
5656

5757
# Offset patterns for different types of tile relationships
58-
self._adjacent_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
58+
self._orthogonal_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
5959
self._diagonal_offsets = [(-1, -1), (-1, 1), (1, -1), (1, 1)]
6060

6161
# Validate puzzle configuration
@@ -178,18 +178,17 @@ def _check_no_diagonal_touching(self, solution: Set[Tuple[int, int]]) -> bool:
178178

179179
for diagonal_pos in diagonal_neighbors:
180180
if diagonal_pos in solution:
181-
dr, dc = diagonal_pos
182181
# Check if they are connected by orthogonal cells
183182
# If there are orthogonal connections, diagonal touching is allowed
184-
orthogonal_connections = [
185-
(row - 1, col) in solution and (dr, col) in solution, # vertical connection
186-
(row + 1, col) in solution and (dr, col) in solution, # vertical connection
187-
(row, col - 1) in solution and (row, dc) in solution, # horizontal connection
188-
(row, col + 1) in solution and (row, dc) in solution # horizontal connection
189-
]
183+
orthogonal_neighbors = self.get_tiles_by_offsets(position, self._orthogonal_offsets)
184+
diagonal_orthogonal_neighbors = self.get_tiles_by_offsets(diagonal_pos, self._orthogonal_offsets)
185+
186+
# Check if there's an orthogonal path connecting the two diagonal cells
187+
shared_orthogonal = orthogonal_neighbors.intersection(diagonal_orthogonal_neighbors)
188+
orthogonal_connections = shared_orthogonal.intersection(solution)
190189

191190
# If diagonal cells touch but no orthogonal connection exists, it's invalid
192-
if not any(orthogonal_connections):
191+
if not orthogonal_connections:
193192
return False
194193

195194
return True
@@ -215,8 +214,8 @@ def _check_snake_path(self, solution: Set[Tuple[int, int]]) -> bool:
215214
# Count neighbors for each cell
216215
neighbor_count = {}
217216
for position in solution:
218-
adjacent_neighbors = self.get_tiles_by_offsets(position, self._adjacent_offsets)
219-
count = len(adjacent_neighbors.intersection(solution))
217+
orthogonal_neighbors = self.get_tiles_by_offsets(position, self._orthogonal_offsets)
218+
count = len(orthogonal_neighbors.intersection(solution))
220219
neighbor_count[position] = count
221220

222221
# Check start and end cells have exactly 1 neighbor
@@ -239,15 +238,15 @@ def dfs(current):
239238
return
240239
visited.add(current)
241240

242-
adjacent_neighbors = self.get_tiles_by_offsets(current, self._adjacent_offsets)
243-
for neighbor in adjacent_neighbors:
241+
orthogonal_neighbors = self.get_tiles_by_offsets(current, self._orthogonal_offsets)
242+
for neighbor in orthogonal_neighbors:
244243
if neighbor in solution and neighbor not in visited:
245244
dfs(neighbor)
246245

247246
# Start DFS from start_cell
248247
dfs(self.start_cell)
249248

250-
# All cells should be reachable from start, and end should be reachable
249+
# All cells should be reachable from start
251250
return len(visited) == len(solution) and self.end_cell in visited
252251

253252
def get_grid_size(self) -> Tuple[int, int]:

tests/test_puzzle.py

Lines changed: 127 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,133 @@ class TestSnakePuzzle:
77

88
def test_basic_puzzle_creation(self):
99
"""Test creating a basic puzzle."""
10-
# Example: Replace with your puzzle parameters
11-
# puzzle = SnakePuzzle(size=5, constraints=[1, 2, 3])
10+
# Test basic 3x3 puzzle with some constraints
11+
row_sums = [2, None, 1] # Row 0 needs 2 cells, row 1 unconstrained, row 2 needs 1 cell
12+
col_sums = [1, 2, None] # Col 0 needs 1 cell, col 1 needs 2 cells, col 2 unconstrained
13+
start_cell = (0, 0)
14+
end_cell = (2, 2)
1215

13-
# Basic instantiation should work with default template
14-
puzzle = SnakePuzzle()
16+
puzzle = SnakePuzzle(row_sums=row_sums, col_sums=col_sums,
17+
start_cell=start_cell, end_cell=end_cell)
1518
assert puzzle is not None
1619

17-
# Once implemented, test basic properties:
18-
# assert puzzle.size == 5
19-
# assert len(puzzle.constraints) == 3
20+
# Test basic properties
21+
assert puzzle.rows == 3
22+
assert puzzle.cols == 3
23+
assert puzzle.row_sums == [2, None, 1]
24+
assert puzzle.col_sums == [1, 2, None]
25+
assert puzzle.get_grid_size() == (3, 3)
26+
assert puzzle.get_start_cell() == (0, 0)
27+
assert puzzle.get_end_cell() == (2, 2)
28+
29+
def test_puzzle_validation(self):
30+
"""Test puzzle input validation."""
31+
start_cell = (0, 0)
32+
end_cell = (2, 2)
33+
34+
# Test invalid dimensions (empty lists)
35+
with pytest.raises(ValueError, match="Number of rows must be positive"):
36+
SnakePuzzle([], [1, 2, 3], start_cell, end_cell)
37+
38+
with pytest.raises(ValueError, match="Number of columns must be positive"):
39+
SnakePuzzle([1, 2, 3], [], start_cell, end_cell)
40+
41+
# Test invalid sum values
42+
with pytest.raises(ValueError, match="Row 0 sum 5 must be between 0 and 3"):
43+
SnakePuzzle([5, 1, 1], [1, 1, 1], start_cell, end_cell)
44+
45+
with pytest.raises(ValueError, match="Column 1 sum -1 must be between 0 and 3"):
46+
SnakePuzzle([1, 1, 1], [1, -1, 1], start_cell, end_cell)
47+
48+
# Test invalid start/end cells
49+
with pytest.raises(ValueError, match="Start cell .* is out of bounds"):
50+
SnakePuzzle([1, 1, 1], [1, 1, 1], (3, 0), end_cell)
51+
52+
with pytest.raises(ValueError, match="End cell .* is out of bounds"):
53+
SnakePuzzle([1, 1, 1], [1, 1, 1], start_cell, (0, 3))
54+
55+
with pytest.raises(ValueError, match="Start cell and end cell cannot be the same"):
56+
SnakePuzzle([1, 1, 1], [1, 1, 1], (0, 0), (0, 0))
57+
58+
def test_solution_validation(self):
59+
"""Test solution validation."""
60+
puzzle = SnakePuzzle([2, 1, 2], [1, 3, 1], start_cell=(0, 0), end_cell=(2, 2))
61+
62+
# Actual solution is (0,0), (0,1), (1,1), (2,1), (2,2)
63+
64+
# Test empty solution
65+
assert not puzzle.is_valid_solution(set())
66+
67+
# Test solution missing start or end cell
68+
assert not puzzle.is_valid_solution({(0, 1), (1, 1), (2, 1)}) # Missing start (0,0) and end (2,2)
69+
70+
# Test out of bounds solution
71+
assert not puzzle.is_valid_solution({(0, 0), (3, 0), (2, 2)}) # row 3 is out of bounds
72+
assert not puzzle.is_valid_solution({(0, 0), (0, 3), (2, 2)}) # col 3 is out of bounds
73+
74+
# Test the actual valid solution
75+
valid_solution = {(0, 0), (0, 1), (1, 1), (2, 1), (2, 2)}
76+
assert puzzle.is_valid_solution(valid_solution)
77+
78+
# Test solution that violates row constraints
79+
invalid_row_solution = {(0, 0), (1, 1), (2, 2)} # Row 0 has only 1 cell but needs 2
80+
assert not puzzle.is_valid_solution(invalid_row_solution)
81+
82+
# Test solution that violates column constraints
83+
invalid_col_solution = {(0, 0), (0, 1), (0, 2), (2, 2)} # Col 1 has only 1 cell but needs 3
84+
assert not puzzle.is_valid_solution(invalid_col_solution)
85+
86+
# Test disconnected path
87+
disconnected_solution = {(0, 0), (0, 1), (2, 1), (2, 2)} # Missing (1,1) - creates gap
88+
assert not puzzle.is_valid_solution(disconnected_solution)
89+
90+
def test_utility_methods(self):
91+
"""Test utility methods."""
92+
puzzle = SnakePuzzle([2, None, 1], [1, 2, None], start_cell=(0, 0), end_cell=(2, 2))
93+
94+
# Test get_row_sum
95+
assert puzzle.get_row_sum(0) == 2
96+
assert puzzle.get_row_sum(1) is None
97+
assert puzzle.get_row_sum(2) == 1
98+
99+
with pytest.raises(IndexError):
100+
puzzle.get_row_sum(3)
101+
102+
# Test get_col_sum
103+
assert puzzle.get_col_sum(0) == 1
104+
assert puzzle.get_col_sum(1) == 2
105+
assert puzzle.get_col_sum(2) is None
106+
107+
with pytest.raises(IndexError):
108+
puzzle.get_col_sum(3)
109+
110+
# Test start/end cell getters
111+
assert puzzle.get_start_cell() == (0, 0)
112+
assert puzzle.get_end_cell() == (2, 2)
113+
114+
# Test helper methods
115+
assert puzzle.is_within_bounds(0, 0)
116+
assert puzzle.is_within_bounds(2, 2)
117+
assert not puzzle.is_within_bounds(-1, 0)
118+
assert not puzzle.is_within_bounds(3, 0)
119+
assert not puzzle.is_within_bounds(0, 3)
120+
121+
# Test get_tile_by_offset
122+
assert puzzle.get_tile_by_offset((1, 1), (-1, 0)) == (0, 1)
123+
assert puzzle.get_tile_by_offset((0, 0), (-1, 0)) is None # out of bounds
124+
125+
# Test get_tiles_by_offsets
126+
adjacent_tiles = puzzle.get_tiles_by_offsets((1, 1), puzzle._orthogonal_offsets)
127+
expected_adjacent = {(0, 1), (2, 1), (1, 0), (1, 2)}
128+
assert adjacent_tiles == expected_adjacent
129+
130+
def test_repr(self):
131+
"""Test string representation."""
132+
puzzle = SnakePuzzle([1, None], [2, 1, None], start_cell=(0, 0), end_cell=(1, 2))
133+
repr_str = repr(puzzle)
134+
assert "SnakePuzzle" in repr_str
135+
assert "rows=2" in repr_str
136+
assert "cols=3" in repr_str
137+
assert "start=(0, 0)" in repr_str
138+
assert "end=(1, 2)" in repr_str
139+

0 commit comments

Comments
 (0)