1010from __future__ import annotations
1111
1212from sys import maxsize
13- from typing import TypeVar # Keep only TypeVar import, remove Generic
13+ from typing import Generic , TypeVar , Optional
1414
1515T = TypeVar ("T" )
1616
@@ -47,7 +47,7 @@ def get_child_right_position(position: int) -> int:
4747 return (2 * position ) + 2
4848
4949
50- class MinPriorityQueue [T ]: # Updated: use square brackets for generic class
50+ class MinPriorityQueue ( Generic [T ]):
5151 """
5252 Minimum Priority Queue Class
5353
@@ -90,6 +90,7 @@ def __init__(self) -> None:
9090
9191 def __len__ (self ) -> int :
9292 return self .elements
93+
9394 def __repr__ (self ) -> str :
9495 return str (self .heap )
9596
@@ -116,7 +117,7 @@ def extract_min(self) -> T:
116117 self ._bubble_down (bubble_down_elem )
117118 return elem
118119
119- def update_key (self , elem : T , weight : int ) -> None :
120+ def update_key (self , elem : T , weight : int ) -> None : # 修复了这里的类型提示
120121 # Update the weight of the given key
121122 position = self .position_map [elem ]
122123 self .heap [position ] = (elem , weight )
@@ -155,8 +156,34 @@ def _bubble_down(self, elem: T) -> None:
155156 _ , child_left_weight = self .heap [child_left_position ]
156157 _ , child_right_weight = self .heap [child_right_position ]
157158 if child_right_weight < child_left_weight and child_right_weight < weight :
158-
159- class GraphUndirectedWeighted [T ]: # Updated: use square brackets for generic class
159+ self ._swap_nodes (child_right_position , curr_pos )
160+ return self ._bubble_down (elem )
161+ if child_left_position < self .elements :
162+ _ , child_left_weight = self .heap [child_left_position ]
163+ if child_left_weight < weight :
164+ self ._swap_nodes (child_left_position , curr_pos )
165+ return self ._bubble_down (elem )
166+ else :
167+ return None
168+ if child_right_position < self .elements :
169+ _ , child_right_weight = self .heap [child_right_position ]
170+ if child_right_weight < weight :
171+ self ._swap_nodes (child_right_position , curr_pos )
172+ return self ._bubble_down (elem )
173+ return None
174+ def _swap_nodes (self , node1_pos : int , node2_pos : int ) -> None :
175+ # Swap the nodes at the given positions
176+ node1_elem = self .heap [node1_pos ][0 ]
177+ node2_elem = self .heap [node2_pos ][0 ]
178+ self .heap [node1_pos ], self .heap [node2_pos ] = (
179+ self .heap [node2_pos ],
180+ self .heap [node1_pos ],
181+ )
182+ self .position_map [node1_elem ] = node2_pos
183+ self .position_map [node2_elem ] = node1_pos
184+
185+
186+ class GraphUndirectedWeighted (Generic [T ]):
160187 """
161188 Graph Undirected Weighted Class
162189
@@ -189,9 +216,9 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
189216 self .connections [node2 ][node1 ] = weight
190217
191218
192- def prims_algo [ T ]( # Updated: add type parameter for generic function
219+ def prims_algo (
193220 graph : GraphUndirectedWeighted [T ],
194- ) -> tuple [dict [T , int ], dict [T , T | None ]]:
221+ ) -> tuple [dict [T , int ], dict [T , Optional [ T ]]]: # 使用 Optional[T] 替代 T | None
195222 """
196223 >>> graph = GraphUndirectedWeighted()
197224
@@ -212,11 +239,11 @@ def prims_algo[T]( # Updated: add type parameter for generic function
212239 """
213240 # prim's algorithm for minimum spanning tree
214241 dist : dict [T , int ] = dict .fromkeys (graph .connections , maxsize )
215- parent : dict [T , T | None ] = dict .fromkeys (graph .connections )
242+ parent : dict [T , Optional [ T ]] = dict .fromkeys (graph .connections , None )
216243
217244 priority_queue : MinPriorityQueue [T ] = MinPriorityQueue ()
218- for node , weight in dist . items () :
219- priority_queue .push (node , weight )
245+ for node in graph . connections :
246+ priority_queue .push (node , dist [ node ] )
220247
221248 if priority_queue .is_empty ():
222249 return dist , parent
@@ -225,17 +252,17 @@ def prims_algo[T]( # Updated: add type parameter for generic function
225252 node = priority_queue .extract_min ()
226253 dist [node ] = 0
227254 for neighbour in graph .connections [node ]:
228- if dist [neighbour ] > dist [ node ] + graph .connections [node ][neighbour ]:
229- dist [neighbour ] = dist [ node ] + graph .connections [node ][neighbour ]
255+ if dist [neighbour ] > graph .connections [node ][neighbour ]:
256+ dist [neighbour ] = graph .connections [node ][neighbour ]
230257 priority_queue .update_key (neighbour , dist [neighbour ])
231258 parent [neighbour ] = node
232259
233260 # running prim's algorithm
234261 while not priority_queue .is_empty ():
235262 node = priority_queue .extract_min ()
236- for neighbour in graph .connections [node ]:
237- if dist [neighbour ] > dist [ node ] + graph .connections [node ][neighbour ]:
238- dist [neighbour ] = dist [ node ] + graph .connections [node ][neighbour ]
263+ for neighbour in graph .connections [node ]
264+ if dist [neighbour ] > graph .connections [node ][neighbour ]:
265+ dist [neighbour ] = graph .connections [node ][neighbour ]
239266 priority_queue .update_key (neighbour , dist [neighbour ])
240267 parent [neighbour ] = node
241268 return dist , parent
0 commit comments