1717
1818def get_parent_position (position : int ) -> int :
1919 """
20- heap helper function get the position of the parent of the current node
20+ Heap helper function to get the position of the parent of the current node
2121
2222 >>> get_parent_position(1)
2323 0
@@ -29,7 +29,7 @@ def get_parent_position(position: int) -> int:
2929
3030def get_child_left_position (position : int ) -> int :
3131 """
32- heap helper function get the position of the left child of the current node
32+ Heap helper function to get the position of the left child of the current node
3333
3434 >>> get_child_left_position(0)
3535 1
@@ -39,29 +39,27 @@ def get_child_left_position(position: int) -> int:
3939
4040def get_child_right_position (position : int ) -> int :
4141 """
42- heap helper function get the position of the right child of the current node
42+ Heap helper function to get the position of the right child of the current node
4343
4444 >>> get_child_right_position(0)
4545 2
4646 """
4747 return (2 * position ) + 2
4848
4949
50+
5051class MinPriorityQueue (Generic [T ]):
5152 """
5253 Minimum Priority Queue Class
5354
5455 Functions:
55- is_empty: function to check if the priority queue is empty
56- push: function to add an element with given priority to the queue
57- extract_min: function to remove and return the element with lowest weight (highest
58- priority)
59- update_key: function to update the weight of the given key
60- _bubble_up: helper function to place a node at the proper position (upward
61- movement)
62- _bubble_down: helper function to place a node at the proper position (downward
63- movement)
64- _swap_nodes: helper function to swap the nodes at the given positions
56+ is_empty: Check if the priority queue is empty
57+ push: Add an element with given priority to the queue
58+ extract_min: Remove and return the element with lowest weight (highest priority)
59+ update_key: Update the weight of the given key
60+ _bubble_up: Place a node at proper position (upward movement)
61+ _bubble_down: Place a node at proper position (downward movement)
62+ _swap_nodes: Swap nodes at given positions
6563
6664 >>> queue = MinPriorityQueue()
6765
@@ -95,18 +93,18 @@ def __repr__(self) -> str:
9593 return str (self .heap )
9694
9795 def is_empty (self ) -> bool :
98- # Check if the priority queue is empty
96+ """ Check if the priority queue is empty"""
9997 return self .elements == 0
10098
10199 def push (self , elem : T , weight : int ) -> None :
102- # Add an element with given priority to the queue
100+ """ Add an element with given priority to the queue"""
103101 self .heap .append ((elem , weight ))
104102 self .position_map [elem ] = self .elements
105103 self .elements += 1
106104 self ._bubble_up (elem )
107105
108106 def extract_min (self ) -> T :
109- # Remove and return the element with lowest weight (highest priority)
107+ """ Remove and return the element with lowest weight (highest priority)"""
110108 if self .elements > 1 :
111109 self ._swap_nodes (0 , self .elements - 1 )
112110 elem , _ = self .heap .pop ()
@@ -117,8 +115,8 @@ def extract_min(self) -> T:
117115 self ._bubble_down (bubble_down_elem )
118116 return elem
119117
120- def update_key (self , elem : T , weight : int ) -> None : # 修复了这里的类型提示
121- # Update the weight of the given key
118+ def update_key (self , elem : T , weight : int ) -> None :
119+ """ Update the weight of the given key"""
122120 position = self .position_map [elem ]
123121 self .heap [position ] = (elem , weight )
124122 if position > 0 :
@@ -130,49 +128,51 @@ def update_key(self, elem: T, weight: int) -> None: # 修复了这里的类型
130128 self ._bubble_down (elem )
131129 else :
132130 self ._bubble_down (elem )
133-
134131 def _bubble_up (self , elem : T ) -> None :
135- # Place a node at the proper position (upward movement) [to be used internally
136- # only]
132+ """Place node at proper position (upward movement) - internal use only"""
137133 curr_pos = self .position_map [elem ]
138134 if curr_pos == 0 :
139- return None
135+ return
140136 parent_position = get_parent_position (curr_pos )
141137 _ , weight = self .heap [curr_pos ]
142138 _ , parent_weight = self .heap [parent_position ]
143139 if parent_weight > weight :
144140 self ._swap_nodes (parent_position , curr_pos )
145- return self ._bubble_up (elem )
146- return None
141+ self ._bubble_up (elem )
147142
148143 def _bubble_down (self , elem : T ) -> None :
149- # Place a node at the proper position (downward movement) [to be used
150- # internally only]
144+ """Place node at proper position (downward movement) - internal use only"""
151145 curr_pos = self .position_map [elem ]
152146 _ , weight = self .heap [curr_pos ]
153147 child_left_position = get_child_left_position (curr_pos )
154148 child_right_position = get_child_right_position (curr_pos )
149+
150+ # Check if both children exist
155151 if child_left_position < self .elements and child_right_position < self .elements :
156152 _ , child_left_weight = self .heap [child_left_position ]
157153 _ , child_right_weight = self .heap [child_right_position ]
158154 if child_right_weight < child_left_weight and child_right_weight < weight :
159155 self ._swap_nodes (child_right_position , curr_pos )
160- return self ._bubble_down (elem )
156+ self ._bubble_down (elem )
157+ return
158+
159+ # Check left child
161160 if child_left_position < self .elements :
162161 _ , child_left_weight = self .heap [child_left_position ]
163162 if child_left_weight < weight :
164163 self ._swap_nodes (child_left_position , curr_pos )
165- return self ._bubble_down (elem )
166- else :
167- return None
164+ self ._bubble_down (elem )
165+ return
166+
167+ # Check right child
168168 if child_right_position < self .elements :
169169 _ , child_right_weight = self .heap [child_right_position ]
170170 if child_right_weight < weight :
171171 self ._swap_nodes (child_right_position , curr_pos )
172- return self ._bubble_down (elem )
173- return None
172+ self ._bubble_down (elem )
173+
174174 def _swap_nodes (self , node1_pos : int , node2_pos : int ) -> None :
175- # Swap the nodes at the given positions
175+ """ Swap nodes at given positions"""
176176 node1_elem = self .heap [node1_pos ][0 ]
177177 node2_elem = self .heap [node2_pos ][0 ]
178178 self .heap [node1_pos ], self .heap [node2_pos ] = (
@@ -188,8 +188,8 @@ class GraphUndirectedWeighted(Generic[T]):
188188 Graph Undirected Weighted Class
189189
190190 Functions:
191- add_node: function to add a node in the graph
192- add_edge: function to add an edge between 2 nodes in the graph
191+ add_node: Add a node to the graph
192+ add_edge: Add an edge between two nodes with given weight
193193 """
194194
195195 def __init__ (self ) -> None :
@@ -203,13 +203,13 @@ def __len__(self) -> int:
203203 return self .nodes
204204
205205 def add_node (self , node : T ) -> None :
206- # Add a node in the graph if it is not in the graph
206+ """ Add a node to the graph if not already present"""
207207 if node not in self .connections :
208208 self .connections [node ] = {}
209209 self .nodes += 1
210210
211211 def add_edge (self , node1 : T , node2 : T , weight : int ) -> None :
212- # Add an edge between 2 nodes in the graph
212+ """ Add an edge between two nodes with given weight"""
213213 self .add_node (node1 )
214214 self .add_node (node2 )
215215 self .connections [node1 ][node2 ] = weight
@@ -218,10 +218,11 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
218218
219219def prims_algo (
220220 graph : GraphUndirectedWeighted [T ],
221- ) -> tuple [dict [T , int ], dict [T , Optional [T ]]]: # 使用 Optional[T] 替代 T | None
221+ ) -> tuple [dict [T , int ], dict [T , Optional [T ]]]:
222222 """
223- >>> graph = GraphUndirectedWeighted()
223+ Prim's algorithm for minimum spanning tree
224224
225+ >>> graph = GraphUndirectedWeighted()
225226 >>> graph.add_edge("a", "b", 3)
226227 >>> graph.add_edge("b", "c", 10)
227228 >>> graph.add_edge("c", "d", 5)
@@ -230,39 +231,52 @@ def prims_algo(
230231
231232 >>> dist, parent = prims_algo(graph)
232233
233- >>> abs( dist["a"] - dist[" b"])
234+ >>> dist["b"]
234235 3
235- >>> abs(dist["d"] - dist["b"])
236- 15
237- >>> abs(dist["a"] - dist["c"])
238- 13
236+ >>> dist["c"]
237+ 10
238+ >>> dist["d"]
239+ 5
240+ >>> parent["b"]
241+ 'a'
242+ >>> parent["c"]
243+ 'b'
244+ >>> parent["d"]
245+ 'c'
239246 """
240- # prim's algorithm for minimum spanning tree
241- dist : dict [T , int ] = dict . fromkeys ( graph .connections , maxsize )
242- parent : dict [T , Optional [T ]] = dict . fromkeys ( graph .connections , None )
247+ # Initialize distance and parent dictionaries
248+ dist : dict [T , int ] = { node : maxsize for node in graph .connections }
249+ parent : dict [T , Optional [T ]] = { node : None for node in graph .connections }
243250
251+ # Create priority queue and add all nodes
244252 priority_queue : MinPriorityQueue [T ] = MinPriorityQueue ()
245253 for node in graph .connections :
246254 priority_queue .push (node , dist [node ])
247255
256+ # Return if graph is empty
248257 if priority_queue .is_empty ():
249258 return dist , parent
250-
251- # initialization
252- node = priority_queue .extract_min ()
253- dist [node ] = 0
254- for neighbour in graph .connections [node ]:
255- if dist [neighbour ] > graph .connections [node ][neighbour ]:
256- dist [neighbour ] = graph .connections [node ][neighbour ]
257- priority_queue .update_key (neighbour , dist [neighbour ])
258- parent [neighbour ] = node
259-
260- # running prim's algorithm
259+ # Start with first node
260+ start_node = priority_queue .extract_min ()
261+ dist [start_node ] = 0
262+
263+ # Update neighbors of start node
264+ for neighbor , weight in graph .connections [start_node ].items ():
265+ if dist [neighbor ] > weight :
266+ dist [neighbor ] = weight
267+ priority_queue .update_key (neighbor , weight )
268+ parent [neighbor ] = start_node
269+
270+ # Main algorithm loop
261271 while not priority_queue .is_empty ():
262272 node = priority_queue .extract_min ()
263- for neighbour in graph .connections [node ]
264- if dist [neighbour ] > graph .connections [node ][neighbour ]:
265- dist [neighbour ] = graph .connections [node ][neighbour ]
266- priority_queue .update_key (neighbour , dist [neighbour ])
267- parent [neighbour ] = node
273+
274+ # Explore neighbors of current node
275+ for neighbor , weight in graph .connections [node ].items ():
276+ # Update if found better connection to tree
277+ if dist [neighbor ] > weight :
278+ dist [neighbor ] = weight
279+ priority_queue .update_key (neighbor , weight )
280+ parent [neighbor ] = node
281+
268282 return dist , parent
0 commit comments