Skip to content

Commit 8404a25

Browse files
authored
Update minimum_spanning_tree_prims2.py
1 parent eed74bf commit 8404a25

File tree

1 file changed

+42
-15
lines changed

1 file changed

+42
-15
lines changed

graphs/minimum_spanning_tree_prims2.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from __future__ import annotations
1111

1212
from sys import maxsize
13-
from typing import TypeVar # Keep only TypeVar import, remove Generic
13+
from typing import Generic, TypeVar, Optional
1414

1515
T = TypeVar("T")
1616

@@ -47,7 +47,7 @@ def get_child_right_position(position: int) -> int:
4747
return (2 * position) + 2
4848

4949

50-
class MinPriorityQueue[T]: # Updated: use square brackets for generic class
50+
class MinPriorityQueue(Generic[T]):
5151
"""
5252
Minimum Priority Queue Class
5353
@@ -90,6 +90,7 @@ def __init__(self) -> None:
9090

9191
def __len__(self) -> int:
9292
return self.elements
93+
9394
def __repr__(self) -> str:
9495
return str(self.heap)
9596

@@ -116,7 +117,7 @@ def extract_min(self) -> T:
116117
self._bubble_down(bubble_down_elem)
117118
return elem
118119

119-
def update_key(self, elem: T, weight: int) -> None:
120+
def update_key(self, elem: T, weight: int) -> None: # 修复了这里的类型提示
120121
# Update the weight of the given key
121122
position = self.position_map[elem]
122123
self.heap[position] = (elem, weight)
@@ -155,8 +156,34 @@ def _bubble_down(self, elem: T) -> None:
155156
_, child_left_weight = self.heap[child_left_position]
156157
_, child_right_weight = self.heap[child_right_position]
157158
if child_right_weight < child_left_weight and child_right_weight < weight:
158-
159-
class GraphUndirectedWeighted[T]: # Updated: use square brackets for generic class
159+
self._swap_nodes(child_right_position, curr_pos)
160+
return self._bubble_down(elem)
161+
if child_left_position < self.elements:
162+
_, child_left_weight = self.heap[child_left_position]
163+
if child_left_weight < weight:
164+
self._swap_nodes(child_left_position, curr_pos)
165+
return self._bubble_down(elem)
166+
else:
167+
return None
168+
if child_right_position < self.elements:
169+
_, child_right_weight = self.heap[child_right_position]
170+
if child_right_weight < weight:
171+
self._swap_nodes(child_right_position, curr_pos)
172+
return self._bubble_down(elem)
173+
return None
174+
def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
175+
# Swap the nodes at the given positions
176+
node1_elem = self.heap[node1_pos][0]
177+
node2_elem = self.heap[node2_pos][0]
178+
self.heap[node1_pos], self.heap[node2_pos] = (
179+
self.heap[node2_pos],
180+
self.heap[node1_pos],
181+
)
182+
self.position_map[node1_elem] = node2_pos
183+
self.position_map[node2_elem] = node1_pos
184+
185+
186+
class GraphUndirectedWeighted(Generic[T]):
160187
"""
161188
Graph Undirected Weighted Class
162189
@@ -189,9 +216,9 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
189216
self.connections[node2][node1] = weight
190217

191218

192-
def prims_algo[T]( # Updated: add type parameter for generic function
219+
def prims_algo(
193220
graph: GraphUndirectedWeighted[T],
194-
) -> tuple[dict[T, int], dict[T, T | None]]:
221+
) -> tuple[dict[T, int], dict[T, Optional[T]]]: # 使用 Optional[T] 替代 T | None
195222
"""
196223
>>> graph = GraphUndirectedWeighted()
197224
@@ -212,11 +239,11 @@ def prims_algo[T]( # Updated: add type parameter for generic function
212239
"""
213240
# prim's algorithm for minimum spanning tree
214241
dist: dict[T, int] = dict.fromkeys(graph.connections, maxsize)
215-
parent: dict[T, T | None] = dict.fromkeys(graph.connections)
242+
parent: dict[T, Optional[T]] = dict.fromkeys(graph.connections, None)
216243

217244
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
218-
for node, weight in dist.items():
219-
priority_queue.push(node, weight)
245+
for node in graph.connections:
246+
priority_queue.push(node, dist[node])
220247

221248
if priority_queue.is_empty():
222249
return dist, parent
@@ -225,17 +252,17 @@ def prims_algo[T]( # Updated: add type parameter for generic function
225252
node = priority_queue.extract_min()
226253
dist[node] = 0
227254
for neighbour in graph.connections[node]:
228-
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
229-
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
255+
if dist[neighbour] > graph.connections[node][neighbour]:
256+
dist[neighbour] = graph.connections[node][neighbour]
230257
priority_queue.update_key(neighbour, dist[neighbour])
231258
parent[neighbour] = node
232259

233260
# running prim's algorithm
234261
while not priority_queue.is_empty():
235262
node = priority_queue.extract_min()
236-
for neighbour in graph.connections[node]:
237-
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
238-
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
263+
for neighbour in graph.connections[node]
264+
if dist[neighbour] > graph.connections[node][neighbour]:
265+
dist[neighbour] = graph.connections[node][neighbour]
239266
priority_queue.update_key(neighbour, dist[neighbour])
240267
parent[neighbour] = node
241268
return dist, parent

0 commit comments

Comments
 (0)