Skip to content

Commit a978c35

Browse files
authored
Update minimum_spanning_tree_prims2.py
1 parent 8404a25 commit a978c35

File tree

1 file changed

+77
-63
lines changed

1 file changed

+77
-63
lines changed

graphs/minimum_spanning_tree_prims2.py

Lines changed: 77 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def get_parent_position(position: int) -> int:
1919
"""
20-
heap helper function get the position of the parent of the current node
20+
Heap helper function to get the position of the parent of the current node
2121
2222
>>> get_parent_position(1)
2323
0
@@ -29,7 +29,7 @@ def get_parent_position(position: int) -> int:
2929

3030
def get_child_left_position(position: int) -> int:
3131
"""
32-
heap helper function get the position of the left child of the current node
32+
Heap helper function to get the position of the left child of the current node
3333
3434
>>> get_child_left_position(0)
3535
1
@@ -39,29 +39,27 @@ def get_child_left_position(position: int) -> int:
3939

4040
def get_child_right_position(position: int) -> int:
4141
"""
42-
heap helper function get the position of the right child of the current node
42+
Heap helper function to get the position of the right child of the current node
4343
4444
>>> get_child_right_position(0)
4545
2
4646
"""
4747
return (2 * position) + 2
4848

4949

50+
5051
class MinPriorityQueue(Generic[T]):
5152
"""
5253
Minimum Priority Queue Class
5354
5455
Functions:
55-
is_empty: function to check if the priority queue is empty
56-
push: function to add an element with given priority to the queue
57-
extract_min: function to remove and return the element with lowest weight (highest
58-
priority)
59-
update_key: function to update the weight of the given key
60-
_bubble_up: helper function to place a node at the proper position (upward
61-
movement)
62-
_bubble_down: helper function to place a node at the proper position (downward
63-
movement)
64-
_swap_nodes: helper function to swap the nodes at the given positions
56+
is_empty: Check if the priority queue is empty
57+
push: Add an element with given priority to the queue
58+
extract_min: Remove and return the element with lowest weight (highest priority)
59+
update_key: Update the weight of the given key
60+
_bubble_up: Place a node at proper position (upward movement)
61+
_bubble_down: Place a node at proper position (downward movement)
62+
_swap_nodes: Swap nodes at given positions
6563
6664
>>> queue = MinPriorityQueue()
6765
@@ -95,18 +93,18 @@ def __repr__(self) -> str:
9593
return str(self.heap)
9694

9795
def is_empty(self) -> bool:
98-
# Check if the priority queue is empty
96+
"""Check if the priority queue is empty"""
9997
return self.elements == 0
10098

10199
def push(self, elem: T, weight: int) -> None:
102-
# Add an element with given priority to the queue
100+
"""Add an element with given priority to the queue"""
103101
self.heap.append((elem, weight))
104102
self.position_map[elem] = self.elements
105103
self.elements += 1
106104
self._bubble_up(elem)
107105

108106
def extract_min(self) -> T:
109-
# Remove and return the element with lowest weight (highest priority)
107+
"""Remove and return the element with lowest weight (highest priority)"""
110108
if self.elements > 1:
111109
self._swap_nodes(0, self.elements - 1)
112110
elem, _ = self.heap.pop()
@@ -117,8 +115,8 @@ def extract_min(self) -> T:
117115
self._bubble_down(bubble_down_elem)
118116
return elem
119117

120-
def update_key(self, elem: T, weight: int) -> None: # 修复了这里的类型提示
121-
# Update the weight of the given key
118+
def update_key(self, elem: T, weight: int) -> None:
119+
"""Update the weight of the given key"""
122120
position = self.position_map[elem]
123121
self.heap[position] = (elem, weight)
124122
if position > 0:
@@ -130,49 +128,51 @@ def update_key(self, elem: T, weight: int) -> None: # 修复了这里的类型
130128
self._bubble_down(elem)
131129
else:
132130
self._bubble_down(elem)
133-
134131
def _bubble_up(self, elem: T) -> None:
135-
# Place a node at the proper position (upward movement) [to be used internally
136-
# only]
132+
"""Place node at proper position (upward movement) - internal use only"""
137133
curr_pos = self.position_map[elem]
138134
if curr_pos == 0:
139-
return None
135+
return
140136
parent_position = get_parent_position(curr_pos)
141137
_, weight = self.heap[curr_pos]
142138
_, parent_weight = self.heap[parent_position]
143139
if parent_weight > weight:
144140
self._swap_nodes(parent_position, curr_pos)
145-
return self._bubble_up(elem)
146-
return None
141+
self._bubble_up(elem)
147142

148143
def _bubble_down(self, elem: T) -> None:
149-
# Place a node at the proper position (downward movement) [to be used
150-
# internally only]
144+
"""Place node at proper position (downward movement) - internal use only"""
151145
curr_pos = self.position_map[elem]
152146
_, weight = self.heap[curr_pos]
153147
child_left_position = get_child_left_position(curr_pos)
154148
child_right_position = get_child_right_position(curr_pos)
149+
150+
# Check if both children exist
155151
if child_left_position < self.elements and child_right_position < self.elements:
156152
_, child_left_weight = self.heap[child_left_position]
157153
_, child_right_weight = self.heap[child_right_position]
158154
if child_right_weight < child_left_weight and child_right_weight < weight:
159155
self._swap_nodes(child_right_position, curr_pos)
160-
return self._bubble_down(elem)
156+
self._bubble_down(elem)
157+
return
158+
159+
# Check left child
161160
if child_left_position < self.elements:
162161
_, child_left_weight = self.heap[child_left_position]
163162
if child_left_weight < weight:
164163
self._swap_nodes(child_left_position, curr_pos)
165-
return self._bubble_down(elem)
166-
else:
167-
return None
164+
self._bubble_down(elem)
165+
return
166+
167+
# Check right child
168168
if child_right_position < self.elements:
169169
_, child_right_weight = self.heap[child_right_position]
170170
if child_right_weight < weight:
171171
self._swap_nodes(child_right_position, curr_pos)
172-
return self._bubble_down(elem)
173-
return None
172+
self._bubble_down(elem)
173+
174174
def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
175-
# Swap the nodes at the given positions
175+
"""Swap nodes at given positions"""
176176
node1_elem = self.heap[node1_pos][0]
177177
node2_elem = self.heap[node2_pos][0]
178178
self.heap[node1_pos], self.heap[node2_pos] = (
@@ -188,8 +188,8 @@ class GraphUndirectedWeighted(Generic[T]):
188188
Graph Undirected Weighted Class
189189
190190
Functions:
191-
add_node: function to add a node in the graph
192-
add_edge: function to add an edge between 2 nodes in the graph
191+
add_node: Add a node to the graph
192+
add_edge: Add an edge between two nodes with given weight
193193
"""
194194

195195
def __init__(self) -> None:
@@ -203,13 +203,13 @@ def __len__(self) -> int:
203203
return self.nodes
204204

205205
def add_node(self, node: T) -> None:
206-
# Add a node in the graph if it is not in the graph
206+
"""Add a node to the graph if not already present"""
207207
if node not in self.connections:
208208
self.connections[node] = {}
209209
self.nodes += 1
210210

211211
def add_edge(self, node1: T, node2: T, weight: int) -> None:
212-
# Add an edge between 2 nodes in the graph
212+
"""Add an edge between two nodes with given weight"""
213213
self.add_node(node1)
214214
self.add_node(node2)
215215
self.connections[node1][node2] = weight
@@ -218,10 +218,11 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
218218

219219
def prims_algo(
220220
graph: GraphUndirectedWeighted[T],
221-
) -> tuple[dict[T, int], dict[T, Optional[T]]]: # 使用 Optional[T] 替代 T | None
221+
) -> tuple[dict[T, int], dict[T, Optional[T]]]:
222222
"""
223-
>>> graph = GraphUndirectedWeighted()
223+
Prim's algorithm for minimum spanning tree
224224
225+
>>> graph = GraphUndirectedWeighted()
225226
>>> graph.add_edge("a", "b", 3)
226227
>>> graph.add_edge("b", "c", 10)
227228
>>> graph.add_edge("c", "d", 5)
@@ -230,39 +231,52 @@ def prims_algo(
230231
231232
>>> dist, parent = prims_algo(graph)
232233
233-
>>> abs(dist["a"] - dist["b"])
234+
>>> dist["b"]
234235
3
235-
>>> abs(dist["d"] - dist["b"])
236-
15
237-
>>> abs(dist["a"] - dist["c"])
238-
13
236+
>>> dist["c"]
237+
10
238+
>>> dist["d"]
239+
5
240+
>>> parent["b"]
241+
'a'
242+
>>> parent["c"]
243+
'b'
244+
>>> parent["d"]
245+
'c'
239246
"""
240-
# prim's algorithm for minimum spanning tree
241-
dist: dict[T, int] = dict.fromkeys(graph.connections, maxsize)
242-
parent: dict[T, Optional[T]] = dict.fromkeys(graph.connections, None)
247+
# Initialize distance and parent dictionaries
248+
dist: dict[T, int] = {node: maxsize for node in graph.connections}
249+
parent: dict[T, Optional[T]] = {node: None for node in graph.connections}
243250

251+
# Create priority queue and add all nodes
244252
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
245253
for node in graph.connections:
246254
priority_queue.push(node, dist[node])
247255

256+
# Return if graph is empty
248257
if priority_queue.is_empty():
249258
return dist, parent
250-
251-
# initialization
252-
node = priority_queue.extract_min()
253-
dist[node] = 0
254-
for neighbour in graph.connections[node]:
255-
if dist[neighbour] > graph.connections[node][neighbour]:
256-
dist[neighbour] = graph.connections[node][neighbour]
257-
priority_queue.update_key(neighbour, dist[neighbour])
258-
parent[neighbour] = node
259-
260-
# running prim's algorithm
259+
# Start with first node
260+
start_node = priority_queue.extract_min()
261+
dist[start_node] = 0
262+
263+
# Update neighbors of start node
264+
for neighbor, weight in graph.connections[start_node].items():
265+
if dist[neighbor] > weight:
266+
dist[neighbor] = weight
267+
priority_queue.update_key(neighbor, weight)
268+
parent[neighbor] = start_node
269+
270+
# Main algorithm loop
261271
while not priority_queue.is_empty():
262272
node = priority_queue.extract_min()
263-
for neighbour in graph.connections[node]
264-
if dist[neighbour] > graph.connections[node][neighbour]:
265-
dist[neighbour] = graph.connections[node][neighbour]
266-
priority_queue.update_key(neighbour, dist[neighbour])
267-
parent[neighbour] = node
273+
274+
# Explore neighbors of current node
275+
for neighbor, weight in graph.connections[node].items():
276+
# Update if found better connection to tree
277+
if dist[neighbor] > weight:
278+
dist[neighbor] = weight
279+
priority_queue.update_key(neighbor, weight)
280+
parent[neighbor] = node
281+
268282
return dist, parent

0 commit comments

Comments
 (0)