Skip to content

Commit 4420ef2

Browse files
committed
automated component build order via kahn's algorithm on dependency graph, working on cycle detection and sending useful dev errors
1 parent c454822 commit 4420ef2

File tree

1 file changed

+151
-42
lines changed

1 file changed

+151
-42
lines changed

src/koi_net/assembler.py

Lines changed: 151 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from collections import deque
12
import inspect
23
from enum import StrEnum
4+
from pprint import pp
35
from typing import Any, Protocol, Self
46
from dataclasses import make_dataclass
57

@@ -15,78 +17,181 @@ class CompType(StrEnum):
1517
FACTORY = "FACTORY"
1618
OBJECT = "OBJECT"
1719

18-
class BuildOrderer(type):
19-
def __new__(cls, name: str, bases: tuple, dct: dict[str]):
20-
"""Sets `cls._build_order` from component order in class definition."""
21-
cls = super().__new__(cls, name, bases, dct)
22-
23-
if "_build_order" not in dct:
24-
components: dict[str, Any] = {}
25-
# adds components from base classes (including cls)
26-
for base in reversed(inspect.getmro(cls)[:-1]):
27-
for k, v in vars(base).items():
28-
# excludes built in and private attributes
29-
if not k.startswith("_"):
30-
components[k] = v
31-
32-
# recipe list constructed from names of non-None components
33-
cls._build_order = [
34-
name for name, _type in components.items()
35-
if _type is not None
36-
]
37-
38-
return cls
3920

4021
class NodeContainer(Protocol):
4122
"""Dummy 'shape' for node containers built by assembler."""
4223
entrypoint = EntryPoint
4324

44-
class NodeAssembler(metaclass=BuildOrderer):
25+
class NodeAssembler:
26+
27+
4528
# Self annotation lying to type checker to reflect typing set in node blueprints
4629
def __new__(self) -> Self:
4730
"""Returns assembled node container."""
48-
return self._build_node()
31+
32+
comps = self._collect_comps()
33+
# pp(list(comps.keys()))
34+
adj, comp_types = self._build_deps(comps)
35+
# pp(adj)
36+
# pp(comp_types)
37+
build_order = self._build_order(adj)
38+
# pp(build_order)
39+
components = self._build_comps(build_order, adj, comp_types)
40+
node = self._build_node(components)
41+
42+
old = list(comps.keys())
43+
new = build_order
44+
45+
result = []
46+
47+
for idx, item in enumerate(new):
48+
old_idx = old.index(item)
49+
if old_idx == idx:
50+
result.append(f"{idx}. {item}")
51+
else:
52+
result.append(f"{idx}. {item} (moved from {old_idx})")
53+
54+
# print("\n".join(result))
55+
56+
return node
57+
58+
@classmethod
59+
def _collect_comps(cls):
60+
comps: dict[str, Any] = {}
61+
# adds components from base classes, including cls)
62+
for base in inspect.getmro(cls)[:-1]:
63+
for k, v in vars(base).items():
64+
# excludes built in, private, and `None` attributes
65+
if k.startswith("_") or v is None:
66+
continue
67+
comps[k] = v
68+
return comps
4969

5070
@classmethod
51-
def _build_deps(cls) -> dict[str, tuple[CompType, list[str]]]:
71+
def _build_deps(cls, comps) -> tuple[dict[str, list[str]], dict[str, CompType]]:
5272
"""Returns dependency graph for components defined in `cls_build_order`.
5373
5474
Graph representation is a dict where each key is a component name,
5575
and the value is tuple containing the component type, and a list
5676
of dependency component names.
5777
"""
5878

79+
comp_types = {}
5980
dep_graph = {}
60-
for comp_name in cls._build_order:
81+
for comp_name in comps:
6182
try:
6283
comp = getattr(cls, comp_name)
6384
except AttributeError:
6485
raise Exception(f"Component '{comp_name}' not found in class definition")
6586

6687
if not callable(comp):
67-
comp_type = CompType.OBJECT
88+
comp_types[comp_name] = CompType.OBJECT
6889
dep_names = []
6990

7091
elif isinstance(comp, type) and issubclass(comp, BaseModel):
71-
comp_type = CompType.OBJECT
92+
comp_types[comp_name] = CompType.OBJECT
7293
dep_names = []
7394

7495
else:
7596
sig = inspect.signature(comp)
76-
comp_type = CompType.FACTORY
97+
comp_types[comp_name] = CompType.FACTORY
7798
dep_names = list(sig.parameters)
7899

79-
dep_graph[comp_name] = (comp_type, dep_names)
100+
dep_graph[comp_name] = dep_names
101+
102+
return dep_graph, comp_types
103+
104+
@classmethod
105+
def _find_cycle(cls, adj) -> list[str]:
106+
visited = set()
107+
stack = []
108+
on_stack = set()
109+
110+
def dfs(node):
111+
visited.add(node)
112+
stack.append(node)
113+
on_stack.add(node)
114+
115+
for nxt in adj[node]:
116+
if nxt not in visited:
117+
cycle = dfs(nxt)
118+
if cycle:
119+
return cycle
120+
121+
elif nxt in on_stack:
122+
idx = stack.index(nxt)
123+
return stack[idx:] + [nxt]
124+
125+
stack.pop()
126+
on_stack.remove(node)
127+
return None
128+
129+
for node in adj:
130+
if node not in visited:
131+
cycle = dfs(node)
132+
if cycle:
133+
return cycle
134+
135+
return None
136+
137+
@classmethod
138+
def _build_order(cls, adj) -> list[str]:
139+
# adj list: n -> outgoing neighbors
140+
141+
# reverse adj list: n -> incoming neighbors
142+
r_adj: dict[str, list[str]] = {}
143+
144+
# computes reverse adjacency list
145+
for node in adj:
146+
r_adj.setdefault(node, [])
147+
for n in adj[node]:
148+
r_adj.setdefault(n, [])
149+
r_adj[n].append(node)
150+
151+
out_degree: dict[str, int] = {
152+
n: len(neighbors)
153+
for n, neighbors in adj.items()
154+
}
155+
156+
queue = deque()
157+
for node in out_degree:
158+
if out_degree[node] == 0:
159+
queue.append(node)
160+
161+
ordered: list[str] = []
162+
while queue:
163+
n = queue.popleft()
164+
ordered.append(n)
165+
for next_n in r_adj[n]:
166+
out_degree[next_n] -= 1
167+
if out_degree[next_n] == 0:
168+
queue.append(next_n)
169+
170+
171+
172+
if len(ordered) != len(adj):
173+
cycle_nodes = set(adj.keys()) - set(ordered)
174+
cycle_adj = {}
175+
for n in list(cycle_nodes):
176+
cycle_adj[n] = set(adj[n]) & cycle_nodes
177+
print(n, "->", cycle_adj[n])
178+
179+
cycle = cls._find_cycle(cycle_adj)
180+
181+
print("FOUND CYCLE")
182+
print(" -> ".join(cycle))
80183

81-
return dep_graph
184+
print(len(ordered), "/", len(adj))
185+
186+
return ordered
82187

83188
@classmethod
84189
def _visualize(cls) -> str:
85190
"""Returns representation of dependency graph in Graphviz DOT language."""
86191
dep_graph = cls._build_deps()
87192

88193
s = "digraph G {\n"
89-
for node, (_, neighbors) in dep_graph.items():
194+
for node, neighbors in dep_graph.items():
90195
sub_s = node
91196
if neighbors:
92197
sub_s += f"-> {', '.join(neighbors)}"
@@ -96,32 +201,36 @@ def _visualize(cls) -> str:
96201
return s
97202

98203
@classmethod
99-
def _build_comps(cls) -> dict[str, Any]:
204+
def _build_comps(
205+
cls,
206+
build_order: list[str],
207+
dep_graph: dict[str, list[str]],
208+
comp_type: dict[str, CompType]
209+
) -> dict[str, Any]:
100210
"""Returns assembled components from dependency graph."""
101-
dep_graph = cls._build_deps()
102211

103212
components: dict[str, Any] = {}
104-
for comp_name, (comp_type, dep_names) in dep_graph.items():
213+
for comp_name in build_order:
214+
# for comp_name, (comp_type, dep_names) in dep_graph.items():
105215
comp = getattr(cls, comp_name, None)
106216

107-
if comp_type == CompType.OBJECT:
217+
if comp_type[comp_name] == CompType.OBJECT:
108218
components[comp_name] = comp
109219

110-
elif comp_type == CompType.FACTORY:
220+
elif comp_type[comp_name] == CompType.FACTORY:
111221
# builds depedency dict for current component
112222
dependencies = {}
113-
for dep_name in dep_names:
114-
if dep_name not in components:
115-
raise Exception(f"Couldn't find required component '{dep_name}'")
116-
dependencies[dep_name] = components[dep_name]
223+
for dep in dep_graph[comp_name]:
224+
if dep not in components:
225+
raise Exception(f"Couldn't find required component '{dep}'")
226+
dependencies[dep] = components[dep]
117227
components[comp_name] = comp(**dependencies)
118228

119229
return components
120230

121231
@classmethod
122-
def _build_node(cls) -> NodeContainer:
232+
def _build_node(cls, components: dict[str, Any]) -> NodeContainer:
123233
"""Returns node container from components."""
124-
components = cls._build_comps()
125234

126235
NodeContainer = make_dataclass(
127236
cls_name="NodeContainer",

0 commit comments

Comments
 (0)