|
1 | 1 | import ast |
2 | 2 | import logging |
3 | | -from collections.abc import Generator |
| 3 | +from collections.abc import Generator, Mapping |
| 4 | +from graphlib import TopologicalSorter |
4 | 5 | from pathlib import Path |
5 | 6 | from typing import Iterable |
6 | 7 |
|
@@ -206,58 +207,18 @@ def find_all_dependents( |
206 | 207 | return visited |
207 | 208 |
|
208 | 209 |
|
209 | | -class CircularDependency(Exception): |
210 | | - pass |
211 | | - |
212 | | - |
213 | | -def topo_sort(graph: dict[str, set[str]]) -> list[str]: |
| 210 | +def topo_sort(graph: Mapping[str, set[str]]) -> list[str]: |
214 | 211 | """Topological sort |
215 | 212 |
|
216 | 213 | Args: |
217 | 214 | graph (Mapping[str, set[str]]): Dependency graph |
218 | 215 |
|
219 | 216 | Raises: |
220 | | - CircularDependency: Raised if a cycle is detected |
| 217 | + CycleError: Raised if a cycle is detected |
221 | 218 |
|
222 | 219 | Returns: |
223 | 220 | list[str]: Ordered list of the nodes |
224 | 221 | """ |
225 | | - visited = list() |
226 | | - stack = list() |
227 | | - rec_stack = set() # Set to track nodes in the current recursion stack |
228 | | - cycle_detected = False # Flag to indicate cycle detection |
229 | | - |
230 | | - def _node_gen(): |
231 | | - for node, edges in graph.items(): |
232 | | - yield node |
233 | | - if edges: |
234 | | - yield from edges |
235 | | - |
236 | | - nodes = set(_node_gen()) |
237 | | - |
238 | | - def visit(node: str): |
239 | | - nonlocal cycle_detected |
240 | | - if node in rec_stack: |
241 | | - cycle_detected = True |
242 | | - return |
243 | | - elif node in visited: |
244 | | - return |
245 | | - |
246 | | - visited.append(node) |
247 | | - rec_stack.add(node) |
248 | | - |
249 | | - adjacent_nodes = graph.get(node) or set() |
250 | | - for adj_node in adjacent_nodes: |
251 | | - visit(adj_node) |
252 | | - |
253 | | - rec_stack.remove(node) |
254 | | - stack.append(node) |
255 | | - |
256 | | - for node in nodes: |
257 | | - if node not in visited: |
258 | | - visit(node) |
259 | | - if cycle_detected: |
260 | | - deps = graph[node] |
261 | | - raise CircularDependency(f"Visited {visited} already, but {node} depends on {deps}") |
262 | 222 |
|
263 | | - return stack |
| 223 | + ts = TopologicalSorter(graph) |
| 224 | + return list(ts.static_order()) |
0 commit comments