Skip to content

Commit 337f4cd

Browse files
committed
AoC 2024 Day 20 - faster
1 parent 2eec489 commit 337f4cd

File tree

4 files changed

+104
-15
lines changed

4 files changed

+104
-15
lines changed

src/main/python/AoC2024_20.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from aoc.common import InputData
99
from aoc.common import SolutionBase
10-
from aoc.graph import bfs
10+
from aoc.graph import bfs_full
1111
from aoc.grid import CharGrid
1212

1313
Input = CharGrid
@@ -39,28 +39,34 @@ def parse_input(self, input_data: InputData) -> Input:
3939
return CharGrid.from_strings(list(input_data))
4040

4141
def solve(self, grid: CharGrid, cheat_len: int, target: int) -> int:
42-
for cell in grid.get_cells():
43-
if grid.get_value(cell) == "S":
44-
start = cell
45-
if grid.get_value(cell) == "E":
46-
end = cell
47-
time, path = bfs(
42+
start = next(
43+
cell for cell in grid.get_cells() if grid.get_value(cell) == "S"
44+
)
45+
distances, _ = bfs_full(
4846
start,
49-
lambda cell: cell == end,
47+
lambda cell: grid.get_value(cell) != "#",
5048
lambda cell: (
5149
n
5250
for n in grid.get_capital_neighbours(cell)
5351
if grid.get_value(n) != "#"
5452
),
5553
)
56-
54+
dist = {(k.row, k.col): v for k, v in distances.items()}
5755
ans = 0
58-
for i1 in range(len(path) - cheat_len):
59-
for i2 in range(i1 + cheat_len, len(path)):
60-
p1, p2 = path[i1], path[i2]
61-
md = abs(p1.row - p2.row) + abs(p1.col - p2.col)
62-
if md <= cheat_len and i2 - i1 - md >= target:
63-
ans += 1
56+
for r, c in dist.keys():
57+
for md in range(2, cheat_len + 1):
58+
for dr in range(md + 1):
59+
dc = md - dr
60+
for rr, cc in {
61+
(r + dr, c + dc),
62+
(r + dr, c - dc),
63+
(r - dr, c + dc),
64+
(r - dr, c - dc),
65+
}:
66+
if (rr, cc) not in dist:
67+
continue
68+
if dist[(rr, cc)] - dist[(r, c)] >= target + md:
69+
ans += 1
6470
return ans
6571

6672
def part_1(self, grid: Input) -> Output1:

src/main/python/aoc/graph.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,30 @@ def bfs(
7171
raise RuntimeError("unsolvable")
7272

7373

74+
def bfs_full(
75+
start: T,
76+
is_end: Callable[[T], bool],
77+
adjacent: Callable[[T], Iterator[T]],
78+
) -> tuple[dict[T, int], dict[T, T]]:
79+
q: deque[tuple[int, T]] = deque()
80+
q.append((0, start))
81+
seen: set[T] = set()
82+
seen.add(start)
83+
parent: dict[T, T] = {}
84+
dists = defaultdict[T, int](int)
85+
while not len(q) == 0:
86+
distance, node = q.popleft()
87+
if is_end(node):
88+
dists[node] = distance
89+
for n in adjacent(node):
90+
if n in seen:
91+
continue
92+
seen.add(n)
93+
parent[n] = node
94+
q.append((distance + 1, n))
95+
return dists, parent
96+
97+
7498
def flood_fill(
7599
start: T,
76100
adjacent: Callable[[T], Iterator[T]],

src/main/python/aoc/grid.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ def to(self, other: Cell) -> Direction:
3838
return Direction.DOWN if self.row < other.row else Direction.UP
3939
raise ValueError("not supported")
4040

41+
def get_all_at_manhattan_distance(self, distance: int) -> Iterator[Cell]:
42+
r, c = self.row, self.col
43+
for dr in range(distance + 1):
44+
dc = distance - dr
45+
for rr, cc in {
46+
(r + dr, c + dc),
47+
(r + dr, c - dc),
48+
(r - dr, c + dc),
49+
(r - dr, c - dc),
50+
}:
51+
yield Cell(rr, cc)
52+
4153

4254
@unique
4355
class IterDir(Enum):

src/test/python/test_grid.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,50 @@ def test_merge(self) -> None:
186186
"333333333",
187187
],
188188
)
189+
190+
191+
class CellTest(unittest.TestCase):
192+
def test_get_all_at_manhattan_distance_1(self) -> None:
193+
cell = Cell(0, 0)
194+
195+
ans = {n for n in cell.get_all_at_manhattan_distance(1)}
196+
197+
self.assertTrue(len(ans) == 4)
198+
199+
def test_get_all_at_manhattan_distance_2(self) -> None:
200+
cell = Cell(0, 0)
201+
202+
ans = {n for n in cell.get_all_at_manhattan_distance(2)}
203+
204+
self.assertTrue(len(ans) == 8)
205+
self.assertTrue(Cell(-2, 0) in ans)
206+
self.assertTrue(Cell(-1, 1) in ans)
207+
self.assertTrue(Cell(0, 2) in ans)
208+
self.assertTrue(Cell(1, 1) in ans)
209+
self.assertTrue(Cell(2, 0) in ans)
210+
self.assertTrue(Cell(1, -1) in ans)
211+
self.assertTrue(Cell(0, -2) in ans)
212+
self.assertTrue(Cell(-1, -1) in ans)
213+
214+
def test_get_all_at_manhattan_distance_3(self) -> None:
215+
cell = Cell(0, 0)
216+
217+
ans = {n for n in cell.get_all_at_manhattan_distance(3)}
218+
219+
self.assertTrue(len(ans) == 12)
220+
self.assertTrue(Cell(-3, 0) in ans)
221+
self.assertTrue(Cell(-2, 1) in ans)
222+
self.assertTrue(Cell(-1, 2) in ans)
223+
self.assertTrue(Cell(0, 3) in ans)
224+
self.assertTrue(Cell(1, 2) in ans)
225+
self.assertTrue(Cell(2, 1) in ans)
226+
self.assertTrue(Cell(3, 0) in ans)
227+
self.assertTrue(Cell(2, -1) in ans)
228+
self.assertTrue(Cell(1, -2) in ans)
229+
self.assertTrue(Cell(0, -3) in ans)
230+
self.assertTrue(Cell(-1, -2) in ans)
231+
self.assertTrue(Cell(-2, -1) in ans)
232+
233+
234+
if __name__ == '__main__':
235+
unittest.main()

0 commit comments

Comments
 (0)