99import sys
1010from collections import defaultdict
1111from queue import PriorityQueue
12- from typing import Callable
1312from typing import Iterator
1413from typing import NamedTuple
1514from typing import Self
1817from aoc .common import SolutionBase
1918from aoc .common import aoc_samples
2019from aoc .geometry import Direction
21- from aoc .geometry import Turn
2220from aoc .grid import Cell
2321from aoc .grid import CharGrid
2422
2523Input = CharGrid
2624Output1 = int
2725Output2 = int
28- State = tuple [Cell , str ]
29-
30- DIRS = {"U" , "R" , "D" , "L" }
31- START_DIR = "R"
26+ State = tuple [int , int , int ]
27+
28+ IDX_TO_DIR = {
29+ 0 : Direction .UP ,
30+ 1 : Direction .RIGHT ,
31+ 2 : Direction .DOWN ,
32+ 3 : Direction .LEFT ,
33+ }
34+ DIR_TO_IDX = {
35+ Direction .UP : 0 ,
36+ Direction .RIGHT : 1 ,
37+ Direction .DOWN : 2 ,
38+ Direction .LEFT : 3 ,
39+ }
40+ TURNS = {
41+ Direction .UP : {Direction .LEFT , Direction .RIGHT },
42+ Direction .RIGHT : {Direction .UP , Direction .DOWN },
43+ Direction .DOWN : {Direction .LEFT , Direction .RIGHT },
44+ Direction .LEFT : {Direction .UP , Direction .DOWN },
45+ }
46+ START_DIR = Direction .RIGHT
47+ FORWARD , BACKWARD = 1 , - 1
3248
3349
3450TEST1 = """\
@@ -85,37 +101,16 @@ def from_grid(cls, grid: CharGrid) -> Self:
85101 end = cell
86102 return cls (grid , start , end )
87103
88- def get_turns (self , direction : Direction ) -> Iterator [str ]:
89- for turn in (Turn .LEFT , Turn .RIGHT ):
90- new_letter = direction .turn (turn ).letter
91- assert new_letter is not None
92- yield new_letter
93-
94- def adjacent_forward (self , state : State ) -> Iterator [State ]:
95- cell , letter = state
96- direction = Direction .from_str (letter )
97- for d in self .get_turns (direction ):
98- yield (cell , d )
99- nxt = cell .at (direction )
100- if self .grid .get_value (nxt ) != "#" :
101- yield (nxt , letter )
102-
103- def adjacent_backward (self , state : State ) -> Iterator [State ]:
104- cell , letter = state
105- direction = Direction .from_str (letter )
106- for d in self .get_turns (direction ):
107- yield (cell , d )
108- nxt = cell .at (direction .turn (Turn .AROUND ))
109- if self .grid .get_value (nxt ) != "#" :
110- yield (nxt , letter )
111-
112- def dijkstra (
113- self ,
114- starts : set [State ],
115- is_end : Callable [[State ], bool ],
116- adjacent : Callable [[State ], Iterator [State ]],
117- get_distance : Callable [[State , State ], int ],
118- ) -> dict [State , int ]:
104+ def adjacent (self , state : State , mode : int ) -> Iterator [State ]:
105+ r , c , dir = state
106+ direction = IDX_TO_DIR [dir ]
107+ for d in TURNS [direction ]:
108+ yield (r , c , DIR_TO_IDX [d ])
109+ nr , nc = r - mode * direction .y , c + mode * direction .x
110+ if self .grid .values [nr ][nc ] != "#" :
111+ yield (nr , nc , dir )
112+
113+ def dijkstra (self , starts : set [State ], mode : int ) -> dict [State , int ]:
119114 q : PriorityQueue [tuple [int , State ]] = PriorityQueue ()
120115 for s in starts :
121116 q .put ((0 , s ))
@@ -125,50 +120,54 @@ def dijkstra(
125120 while not q .empty ():
126121 dist , node = q .get ()
127122 curr_dist = dists [node ]
128- for n in adjacent (node ):
129- new_dist = curr_dist + get_distance ( node , n )
123+ for n in self . adjacent (node , mode ):
124+ new_dist = curr_dist + ( 1 if node [ 2 ] == n [ 2 ] else 1000 )
130125 if new_dist < dists [n ]:
131126 dists [n ] = new_dist
132127 q .put ((new_dist , n ))
133128 return dists
134129
135- def forward_distances (self ) -> dict [State , int ]:
136- return self .dijkstra (
137- {(self .start , START_DIR )},
138- lambda node : node [0 ] == self .end ,
139- self .adjacent_forward ,
140- lambda curr , nxt : 1 if curr [1 ] == nxt [1 ] else 1000 ,
130+ def forward_distances (self ) -> tuple [dict [State , int ], int ]:
131+ starts = {(self .start .row , self .start .col , DIR_TO_IDX [START_DIR ])}
132+ distances = self .dijkstra (starts , FORWARD )
133+ best = next (
134+ v
135+ for k , v in distances .items ()
136+ if (k [0 ], k [1 ]) == (self .end .row , self .end .col )
141137 )
138+ return distances , best
142139
143140 def backward_distances (self ) -> dict [State , int ]:
141+ starts = itertools .product (
142+ [self .end ],
143+ (DIR_TO_IDX [dir ] for dir in Direction .capitals ()),
144+ )
144145 return self .dijkstra (
145- {_ for _ in itertools .product ([self .end ], DIRS )},
146- lambda node : node [0 ] == self .start ,
147- self .adjacent_backward ,
148- lambda curr , nxt : 1 if curr [1 ] == nxt [1 ] else 1000 ,
146+ set ((s [0 ].row , s [0 ].col , s [1 ]) for s in starts ), BACKWARD
149147 )
150148
151149 def parse_input (self , input_data : InputData ) -> Input :
152150 return CharGrid .from_strings (list (input_data ))
153151
154152 def part_1 (self , grid : Input ) -> Output1 :
155153 maze = Solution .ReindeerMaze .from_grid (grid )
156- distances = maze .forward_distances ()
157- return next ( v for k , v in distances . items () if k [ 0 ] == maze . end )
154+ _ , best = maze .forward_distances ()
155+ return best
158156
159157 def part_2 (self , grid : Input ) -> Output2 :
160158 maze = Solution .ReindeerMaze .from_grid (grid )
161- forw_dists = maze .forward_distances ()
162- best = next (v for k , v in forw_dists .items () if k [0 ] == maze .end )
163- backw_dists = maze .backward_distances ()
159+ forward_distances , best = maze .forward_distances ()
160+ backward_distances = maze .backward_distances ()
164161 all_tile_states = itertools .product (
165162 grid .find_all_matching (lambda cell : grid .get_value (cell ) != "#" ),
166- DIRS ,
163+ ( DIR_TO_IDX [ dir ] for dir in Direction . capitals ()) ,
167164 )
168165 best_tiles = {
169166 cell
170167 for cell , dir in all_tile_states
171- if forw_dists [(cell , dir )] + backw_dists [(cell , dir )] == best
168+ if forward_distances [(cell .row , cell .col , dir )]
169+ + backward_distances [(cell .row , cell .col , dir )]
170+ == best
172171 }
173172 return len (best_tiles )
174173
0 commit comments