Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/fromager/dependency_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import dataclasses
import graphlib
import json
import logging
import pathlib
import threading
import typing

from packaging.requirements import Requirement
Expand Down Expand Up @@ -312,3 +314,133 @@ def _depth_first_traversal(
yield from self._depth_first_traversal(
edge.destination_node.children, visited, match_dep_types
)


class TrackingTopologicalSorter:
    """A topological sorter that tracks nodes in progress

    ``graphlib.TopologicalSorter.get_ready()`` returns each node only once.
    The tracking topological sorter keeps track of which nodes are marked
    as done. The ``get_available()`` method returns nodes again and again,
    until they are marked as done. The graph is active until all nodes are
    marked as done.

    Individual nodes can be marked as exclusive nodes. ``get_available()``
    treats exclusive nodes specially and returns, in order of preference:

    1. one or more non-exclusive nodes
    2. exactly one exclusive node that is a predecessor of another node
    3. exactly one exclusive node

    The class uses a lock for ``is_active()``, ``get_available()``, and
    ``done()``, so these methods can be used from a thread pool and from
    future callbacks. ``add()`` and ``prepare()`` do not acquire the lock.
    """

    __slots__ = (
        "_dep_nodes",
        "_exclusive_nodes",
        "_in_progress_nodes",
        "_lock",
        "_topo",
    )

    def __init__(
        self,
        graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]]
        | None = None,
    ) -> None:
        """Create a sorter, optionally pre-populated from *graph*

        *graph* maps each node to an iterable of its predecessors; every
        entry is forwarded to ``add()``.
        """
        self._topo: graphlib.TopologicalSorter[DependencyNode] = (
            graphlib.TopologicalSorter()
        )
        # set of nodes that were handed out, but are not done, yet
        self._in_progress_nodes: set[DependencyNode] = set()
        # set of nodes that are predecessors of other nodes
        self._dep_nodes: set[DependencyNode] = set()
        # dict of nodes -> priority; dependency: -1, leaf: +1
        self._exclusive_nodes: dict[DependencyNode, int] = {}
        self._lock = threading.Lock()
        if graph is not None:
            for node, predecessors in graph.items():
                self.add(node, *predecessors)

    @property
    def dependency_nodes(self) -> set[DependencyNode]:
        """Nodes that other nodes depend on (returns a new set)"""
        return self._dep_nodes.copy()

    @property
    def exclusive_nodes(self) -> set[DependencyNode]:
        """Nodes that are marked as exclusive (returns a new set)"""
        return set(self._exclusive_nodes)

    def add(
        self,
        node: DependencyNode,
        *predecessors: DependencyNode,
        exclusive: bool = False,
    ) -> None:
        """Add new node

        Can be called multiple times for a node to add more predecessors or
        to mark a node as exclusive. Exclusive nodes cannot be unmarked.
        """
        self._topo.add(node, *predecessors)
        self._dep_nodes.update(predecessors)
        if exclusive:
            # start as 'leaf' priority; prepare() promotes dependencies
            self._exclusive_nodes[node] = 1

    def prepare(self) -> None:
        """Prepare and check for cyclic dependencies"""
        self._topo.prepare()
        # only values of existing keys are replaced, so mutating the dict
        # while iterating over it is safe.
        for node in self._exclusive_nodes:
            if node in self._dep_nodes:
                # give dependency nodes a higher priority
                self._exclusive_nodes[node] = -1

    def is_active(self) -> bool:
        """Return True while nodes are in progress or more can become ready"""
        with self._lock:
            return bool(self._in_progress_nodes) or self._topo.is_active()

    def __bool__(self) -> bool:
        return self.is_active()

    def get_available(self) -> set[DependencyNode]:
        """Get available nodes

        A node can be returned multiple times until it is marked as 'done'.
        Returns an empty set when no node is ready or in progress, e.g.
        after all nodes have been marked as done.
        """
        with self._lock:
            # get ready nodes, update in progress nodes.
            self._in_progress_nodes.update(self._topo.get_ready())

            # get and prefer non-exclusive nodes. Exclusive nodes are
            # 'heavy' nodes that take a long time to build. Start with
            # 'light' nodes first.
            exclusive_nodes = self._exclusive_nodes
            non_exclusive = self._in_progress_nodes.difference(exclusive_nodes)
            if non_exclusive:
                # set.difference() returns a new set object
                return non_exclusive

            pending_exclusive = self._in_progress_nodes.intersection(
                exclusive_nodes
            )
            if not pending_exclusive:
                # nothing to hand out; avoid failing on an empty sequence
                return set()

            # return a single exclusive node, prefer nodes that are a
            # dependency of other nodes (lower priority value wins, the
            # node itself breaks ties deterministically).
            return {
                min(
                    pending_exclusive,
                    key=lambda node: (exclusive_nodes[node], node),
                )
            }

    def done(self, *nodes: DependencyNode) -> None:
        """Mark nodes as done

        Nodes are processed in order. If the underlying sorter rejects a
        node, the exception propagates and the remaining nodes stay in
        progress, so they are still returned by ``get_available()``.
        """
        with self._lock:
            for node in nodes:
                # let graphlib validate the node first, then update our
                # own bookkeeping to keep both in sync on errors.
                self._topo.done(node)
                self._in_progress_nodes.discard(node)

    def static_batches(self) -> typing.Iterable[set[DependencyNode]]:
        """Prepare the graph and yield batches of available nodes

        Every yielded batch is marked as done when the generator is
        resumed, which makes the next batch available.
        """
        self.prepare()
        while self.is_active():
            nodes = self.get_available()
            yield nodes
            self.done(*nodes)
67 changes: 66 additions & 1 deletion tests/test_dependency_graph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
import typing

import pytest
from packaging.utils import canonicalize_name
from packaging.version import Version

from fromager.dependency_graph import DependencyNode
from fromager.dependency_graph import DependencyNode, TrackingTopologicalSorter


def mknode(name: str, version: str = "1.0", **kwargs) -> DependencyNode:
Expand Down Expand Up @@ -59,3 +60,67 @@ def test_dependencynode_dataclass():
assert root.canonicalized_name == ""
assert root.version == Version("0.0")
assert root.key == ""


def test_tracking_topology_sorter() -> None:
    """Check one-by-one, batched, and exclusive-node ordering of the sorter."""
    a, b, c, d, e, f = (mknode(name) for name in "abcdef")

    # node -> predecessors
    graph: typing.Mapping[DependencyNode, typing.Iterable[DependencyNode]] = {
        a: [b, c],
        b: [c, d],
        d: [e],
        f: [d],
    }
    expected_deps = {b, c, d, e}

    sorter = TrackingTopologicalSorter(graph)
    sorter.prepare()

    assert sorter.dependency_nodes == expected_deps
    assert sorter.exclusive_nodes == set()
    # every property access hands out a fresh set object
    assert sorter.dependency_nodes is not sorter.dependency_nodes
    assert sorter.exclusive_nodes is not sorter.exclusive_nodes

    # process exactly one node per iteration, always the smallest one
    order: list[DependencyNode] = []
    while sorter.is_active():
        first = sorted(sorter.get_available())[0]
        order.append(first)
        sorter.done(first)
    # c and e have no dependency
    # d depends on e
    # b comes after d
    # f comes after d, but sorting pushes it after a
    # a depends on b
    assert order == [c, e, d, b, a, f]

    # batch mode without exclusive nodes
    sorter = TrackingTopologicalSorter(graph)
    assert sorter.dependency_nodes == expected_deps
    assert sorter.exclusive_nodes == set()
    plain_batches = list(sorter.static_batches())
    assert plain_batches == [
        {c, e},
        {d},
        {b, f},
        {a},
    ]

    # batch mode with b marked as exclusive
    sorter = TrackingTopologicalSorter(graph)
    sorter.add(b, exclusive=True)
    assert sorter.dependency_nodes == expected_deps
    assert sorter.exclusive_nodes == {b}
    exclusive_batches = list(sorter.static_batches())
    assert exclusive_batches == [
        {c, e},
        {d},
        {f},
        {b},
        {a},
    ]
Loading