1010from __future__ import annotations
1111
1212from sys import maxsize
13- from typing import Generic , TypeVar , Optional
13+ from typing import Generic , TypeVar
1414
1515T = TypeVar ("T" )
1616
@@ -146,7 +146,7 @@ def _bubble_down(self, elem: T) -> None:
146146 _ , weight = self .heap [curr_pos ]
147147 child_left_position = get_child_left_position (curr_pos )
148148 child_right_position = get_child_right_position (curr_pos )
149-
149+
150150 # Check if both children exist
151151 if child_left_position < self .elements and child_right_position < self .elements :
152152 _ , child_left_weight = self .heap [child_left_position ]
@@ -155,15 +155,15 @@ def _bubble_down(self, elem: T) -> None:
155155 self ._swap_nodes (child_right_position , curr_pos )
156156 self ._bubble_down (elem )
157157 return
158-
158+
159159 # Check left child
160160 if child_left_position < self .elements :
161161 _ , child_left_weight = self .heap [child_left_position ]
162162 if child_left_weight < weight :
163163 self ._swap_nodes (child_left_position , curr_pos )
164164 self ._bubble_down (elem )
165165 return
166-
166+
167167 # Check right child
168168 if child_right_position < self .elements :
169169 _ , child_right_weight = self .heap [child_right_position ]
@@ -218,7 +218,7 @@ def add_edge(self, node1: T, node2: T, weight: int) -> None:
218218
219219def prims_algo (
220220 graph : GraphUndirectedWeighted [T ],
221- ) -> tuple [dict [T , int ], dict [T , Optional [ T ] ]]:
221+ ) -> tuple [dict [T , int ], dict [T , T | None ]]:
222222 """
223223 Prim's algorithm for minimum spanning tree
224224
@@ -244,9 +244,9 @@ def prims_algo(
244244 >>> parent["d"]
245245 'c'
246246 """
247- # Initialize distance and parent dictionaries
248- dist : dict [T , int ] = { node : maxsize for node in graph .connections }
249- parent : dict [T , Optional [ T ]] = { node : None for node in graph .connections }
247+ # Initialize distance and parent dictionaries using dict.fromkeys
248+ dist : dict [T , int ] = dict . fromkeys ( graph .connections , maxsize )
249+ parent : dict [T , T | None ] = dict . fromkeys ( graph .connections , None )
250250
251251 # Create priority queue and add all nodes
252252 priority_queue : MinPriorityQueue [T ] = MinPriorityQueue ()
@@ -256,6 +256,7 @@ def prims_algo(
256256 # Return if graph is empty
257257 if priority_queue .is_empty ():
258258 return dist , parent
259+
259260 # Start with first node
260261 start_node = priority_queue .extract_min ()
261262 dist [start_node ] = 0
@@ -270,7 +271,7 @@ def prims_algo(
270271 # Main algorithm loop
271272 while not priority_queue .is_empty ():
272273 node = priority_queue .extract_min ()
273-
274+
274275 # Explore neighbors of current node
275276 for neighbor , weight in graph .connections [node ].items ():
276277 # Update if found better connection to tree
0 commit comments