Skip to content

Commit dd16651

Browse files
authored
Update minimum_spanning_tree_prims2.py
1 parent a7e4603 commit dd16651

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

graphs/minimum_spanning_tree_prims2.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from __future__ import annotations
1111

1212
from sys import maxsize
13-
from typing import Generic, TypeVar, Optional
13+
from typing import Generic, TypeVar
1414

1515
T = TypeVar("T")
1616

@@ -146,7 +146,7 @@ def _bubble_down(self, elem: T) -> None:
146146
_, weight = self.heap[curr_pos]
147147
child_left_position = get_child_left_position(curr_pos)
148148
child_right_position = get_child_right_position(curr_pos)
149-
149+
150150
# Check if both children exist
151151
if child_left_position < self.elements and child_right_position < self.elements:
152152
_, child_left_weight = self.heap[child_left_position]
@@ -155,15 +155,15 @@ def _bubble_down(self, elem: T) -> None:
155155
self._swap_nodes(child_right_position, curr_pos)
156156
self._bubble_down(elem)
157157
return
158-
158+
159159
# Check left child
160160
if child_left_position < self.elements:
161161
_, child_left_weight = self.heap[child_left_position]
162162
if child_left_weight < weight:
163163
self._swap_nodes(child_left_position, curr_pos)
164164
self._bubble_down(elem)
165165
return
166-
166+
167167
# Check right child
168168
if child_right_position < self.elements:
169169
_, child_right_weight = self.heap[child_right_position]
@@ -218,7 +218,7 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
218218

219219
def prims_algo(
220220
graph: GraphUndirectedWeighted[T],
221-
) -> tuple[dict[T, int], dict[T, Optional[T]]]:
221+
) -> tuple[dict[T, int], dict[T, T | None]]:
222222
"""
223223
Prim's algorithm for minimum spanning tree
224224
@@ -244,9 +244,9 @@ def prims_algo(
244244
>>> parent["d"]
245245
'c'
246246
"""
247-
# Initialize distance and parent dictionaries
248-
dist: dict[T, int] = {node: maxsize for node in graph.connections}
249-
parent: dict[T, Optional[T]] = {node: None for node in graph.connections}
247+
# Initialize distance and parent dictionaries using dict.fromkeys
248+
dist: dict[T, int] = dict.fromkeys(graph.connections, maxsize)
249+
parent: dict[T, T | None] = dict.fromkeys(graph.connections, None)
250250

251251
# Create priority queue and add all nodes
252252
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
@@ -256,6 +256,7 @@ def prims_algo(
256256
# Return if graph is empty
257257
if priority_queue.is_empty():
258258
return dist, parent
259+
259260
# Start with first node
260261
start_node = priority_queue.extract_min()
261262
dist[start_node] = 0
@@ -270,7 +271,7 @@ def prims_algo(
270271
# Main algorithm loop
271272
while not priority_queue.is_empty():
272273
node = priority_queue.extract_min()
273-
274+
274275
# Explore neighbors of current node
275276
for neighbor, weight in graph.connections[node].items():
276277
# Update if found better connection to tree

0 commit comments

Comments
 (0)