diff --git a/appdaemon/dependency.py b/appdaemon/dependency.py index 8787d185c..812099f3f 100644 --- a/appdaemon/dependency.py +++ b/appdaemon/dependency.py @@ -1,6 +1,7 @@ import ast import logging -from collections.abc import Generator +from collections.abc import Generator, Mapping +from graphlib import TopologicalSorter from pathlib import Path from typing import Iterable @@ -206,58 +207,18 @@ def find_all_dependents( return visited -class CircularDependency(Exception): - pass - - -def topo_sort(graph: dict[str, set[str]]) -> list[str]: +def topo_sort(graph: Mapping[str, set[str]]) -> list[str]: """Topological sort Args: graph (Mapping[str, set[str]]): Dependency graph Raises: - CircularDependency: Raised if a cycle is detected + CycleError: Raised if a cycle is detected Returns: list[str]: Ordered list of the nodes """ - visited = list() - stack = list() - rec_stack = set() # Set to track nodes in the current recursion stack - cycle_detected = False # Flag to indicate cycle detection - - def _node_gen(): - for node, edges in graph.items(): - yield node - if edges: - yield from edges - - nodes = set(_node_gen()) - - def visit(node: str): - nonlocal cycle_detected - if node in rec_stack: - cycle_detected = True - return - elif node in visited: - return - - visited.append(node) - rec_stack.add(node) - - adjacent_nodes = graph.get(node) or set() - for adj_node in adjacent_nodes: - visit(adj_node) - - rec_stack.remove(node) - stack.append(node) - - for node in nodes: - if node not in visited: - visit(node) - if cycle_detected: - deps = graph[node] - raise CircularDependency(f"Visited {visited} already, but {node} depends on {deps}") - return stack + ts = TopologicalSorter(graph) + return list(ts.static_order())