diff --git a/src/fromager/dependency_graph.py b/src/fromager/dependency_graph.py
index efde31cf..809e849e 100644
--- a/src/fromager/dependency_graph.py
+++ b/src/fromager/dependency_graph.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import dataclasses
+import graphlib
 import json
 import logging
 import pathlib
+import threading
 import typing
 
 from packaging.requirements import Requirement
@@ -312,3 +314,160 @@ def _depth_first_traversal(
             yield from self._depth_first_traversal(
                 edge.destination_node.children, visited, match_dep_types
             )
+
+
+class TrackingTopologicalSorter:
+    """A thread-safe topological sorter that tracks nodes in progress
+
+    ``TopologicalSorter.get_ready()`` returns each node only once. The
+    tracking topological sorter keeps track of which nodes are marked as
+    done. The ``get_available()`` method returns nodes again and again,
+    until they are marked as done. The graph is active until all nodes are
+    marked as done.
+
+    Individual nodes can be marked as exclusive nodes. ``get_available()``
+    treats exclusive nodes specially and returns:
+
+    1. one or more non-exclusive nodes
+    2. exactly one exclusive node that is a predecessor of another node
+    3. exactly one exclusive node
+    4. an empty set when no node is in progress
+
+    The class uses a lock for ``is_active()``, ``get_available()``, and
+    ``done()``, so these methods can be used from a thread pool and from
+    future callbacks. ``add()`` and ``prepare()`` do not take the lock.
+    """
+
+    __slots__ = (
+        "_dep_nodes",
+        "_exclusive_nodes",
+        "_in_progress_nodes",
+        "_lock",
+        "_topo",
+    )
+
+    def __init__(
+        self,
+        graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
+        | None = None,
+    ) -> None:
+        """Create a sorter, optionally filled from ``graph``
+
+        ``graph`` maps each node to an iterable of its predecessors.
+        """
+        self._topo: graphlib.TopologicalSorter[DependencyNode] = (
+            graphlib.TopologicalSorter()
+        )
+        # set of nodes that are not done, yet
+        self._in_progress_nodes: set[DependencyNode] = set()
+        # set of nodes that are predecessors of other nodes
+        self._dep_nodes: set[DependencyNode] = set()
+        # dict of nodes -> priority; dependency: -1, leaf: +1
+        self._exclusive_nodes: dict[DependencyNode, int] = {}
+        self._lock = threading.Lock()
+        if graph is not None:
+            for node, predecessors in graph.items():
+                self.add(node, *predecessors)
+
+    @property
+    def dependency_nodes(self) -> set[DependencyNode]:
+        """Nodes that other nodes depend on"""
+        return self._dep_nodes.copy()
+
+    @property
+    def exclusive_nodes(self) -> set[DependencyNode]:
+        """Nodes that are marked as exclusive"""
+        return set(self._exclusive_nodes)
+
+    def add(
+        self,
+        node: DependencyNode,
+        *predecessors: DependencyNode,
+        exclusive: bool = False,
+    ) -> None:
+        """Add new node
+
+        Can be called multiple times for a node to add more predecessors or
+        to mark a node as exclusive. Exclusive nodes cannot be unmarked.
+        """
+        self._topo.add(node, *predecessors)
+        self._dep_nodes.update(predecessors)
+        if exclusive:
+            self._exclusive_nodes[node] = 1
+
+    def prepare(self) -> None:
+        """Prepare and check for cyclic dependencies"""
+        self._topo.prepare()
+        for node in self._exclusive_nodes:
+            if node in self._dep_nodes:
+                # give dependency nodes a higher priority
+                self._exclusive_nodes[node] = -1
+
+    def is_active(self) -> bool:
+        """Return True while nodes are in progress or not handed out, yet"""
+        with self._lock:
+            return bool(self._in_progress_nodes) or self._topo.is_active()
+
+    def __bool__(self) -> bool:
+        """Same as ``is_active()``"""
+        return self.is_active()
+
+    def get_available(self) -> set[DependencyNode]:
+        """Get available nodes
+
+        A node can be returned multiple times until it is marked as 'done'.
+        Returns an empty set when no node is in progress, for example for
+        an empty graph or after all nodes are marked as done.
+        """
+        with self._lock:
+            # get ready nodes, update in progress nodes.
+            ready = self._topo.get_ready()
+            self._in_progress_nodes.update(ready)
+
+            # get and prefer non-exclusive nodes. Exclusive nodes are
+            # 'heavy' nodes, that take a long time to build. Start with
+            # 'light' nodes first.
+            exclusive_nodes = self._exclusive_nodes
+            non_exclusive = self._in_progress_nodes.difference(exclusive_nodes)
+            if non_exclusive:
+                # set.difference() returns a new set object
+                return non_exclusive
+
+            exclusive = self._in_progress_nodes.intersection(exclusive_nodes)
+            if not exclusive:
+                # nothing is in progress; avoid IndexError on an empty list
+                return set()
+
+            # return a single exclusive node, prefer nodes that are a
+            # dependency of other nodes.
+            first = min(
+                exclusive,
+                key=lambda node: (exclusive_nodes[node], node),
+            )
+            return {first}
+
+    def done(self, *nodes: DependencyNode) -> None:
+        """Mark nodes as done
+
+        Raises ``ValueError`` when a node is not in progress. The check runs
+        before any state is changed, so a failed call leaves the sorter
+        untouched. Nodes passed more than once are marked once.
+        """
+        with self._lock:
+            finished = set(nodes)
+            not_in_progress = finished.difference(self._in_progress_nodes)
+            if not_in_progress:
+                raise ValueError(f"nodes are not in progress: {not_in_progress!r}")
+            self._topo.done(*finished)
+            self._in_progress_nodes.difference_update(finished)
+
+    def static_batches(self) -> typing.Iterable[set[DependencyNode]]:
+        """Prepare the graph and yield batches of available nodes
+
+        A batch is marked as done when the next batch is requested.
+        """
+        self.prepare()
+        while self.is_active():
+            nodes = self.get_available()
+            yield nodes
+            self.done(*nodes)
diff --git a/tests/test_dependency_graph.py b/tests/test_dependency_graph.py
index 5ce835ea..c7ce4b6f 100644
--- a/tests/test_dependency_graph.py
+++ b/tests/test_dependency_graph.py
@@ -1,10 +1,11 @@
 import dataclasses
+import typing
 
 import pytest
 from packaging.utils import canonicalize_name
 from packaging.version import Version
 
-from fromager.dependency_graph import DependencyNode
+from fromager.dependency_graph import DependencyNode, TrackingTopologicalSorter
 
 
 def mknode(name: str, version: str = "1.0", **kwargs) -> DependencyNode:
@@ -59,3 +60,74 @@ def test_dependencynode_dataclass():
     assert root.canonicalized_name == ""
     assert root.version == Version("0.0")
     assert root.key == ""
+
+
+def test_tracking_topology_sorter() -> None:
+    a = mknode("a")
+    b = mknode("b")
+    c = mknode("c")
+    d = mknode("d")
+    e = mknode("e")
+    f = mknode("f")
+
+    graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
+    graph = {
+        a: [b, c],
+        b: [c, d],
+        d: [e],
+        f: [d],
+    }
+
+    topo = TrackingTopologicalSorter(graph)
+    topo.prepare()
+
+    assert topo.dependency_nodes == {b, c, d, e}
+    assert topo.exclusive_nodes == set()
+    # properties return new objects
+    assert topo.dependency_nodes is not topo.dependency_nodes
+    assert topo.exclusive_nodes is not topo.exclusive_nodes
+
+    processed: list[DependencyNode] = []
+    while topo.is_active():
+        ready = sorted(topo.get_available())
+        r0 = ready[0]
+        processed.append(r0)
+        topo.done(r0)
+    # c and e have no dependency
+    # d depends on e
+    # b after d
+    # f after d, but sorting pushes it after a
+    # a on b
+    assert processed == [c, e, d, b, a, f]
+
+    topo = TrackingTopologicalSorter(graph)
+    assert topo.dependency_nodes == {b, c, d, e}
+    assert topo.exclusive_nodes == set()
+    batches = list(topo.static_batches())
+    assert batches == [
+        {c, e},
+        {d},
+        {b, f},
+        {a},
+    ]
+
+    topo = TrackingTopologicalSorter(graph)
+    # mark b as exclusive
+    topo.add(b, exclusive=True)
+    assert topo.dependency_nodes == {b, c, d, e}
+    assert topo.exclusive_nodes == {b}
+    batches = list(topo.static_batches())
+    assert batches == [
+        {c, e},
+        {d},
+        {f},
+        {b},
+        {a},
+    ]
+
+    # an exhausted sorter has nothing left to hand out
+    assert not topo.is_active()
+    assert topo.get_available() == set()
+    # nodes that are not in progress are rejected
+    with pytest.raises(ValueError):
+        topo.done(a)