-
Notifications
You must be signed in to change notification settings - Fork 6
Simplify Locs, remove -1s; symlink Input nodes in start_graph, and graph in each map-elem/loop-iter #280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Simplify Locs, remove -1s; symlink Input nodes in start_graph, and graph in each map-elem/loop-iter #280
Changes from 12 commits
89717b5
92de749
70fb35e
0af5259
2e48e46
ecf0e08
9ea962a
ace7c46
233c2b2
12090f1
d0f0bc0
a66eb64
0386882
d29ce52
3ea56de
a3c47de
3a35fb8
163b90c
4ef2440
d35bd23
cb939f7
4b50336
ef04bd4
a19ce7c
c2e26d0
844c089
e027bfc
6ff25b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| BODY_PORT = "body" | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm, perhaps I should use this everywhere (and move to labels.py) rather than delete it!? |
||
| PACKAGE_PATH = Path(os.path.dirname(os.path.realpath(__file__))) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| from pathlib import Path | ||
| import subprocess | ||
| import sys | ||
| from typing import Sequence | ||
|
|
||
| from tierkreis.controller.data.core import PortID | ||
| from tierkreis.controller.data.types import bytes_from_ptype, ptype_from_bytes | ||
|
|
@@ -22,24 +23,44 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| # ALAN this should really be NodeRunTask (or RunNodeTask) | ||
acl-cqc marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| @dataclass | ||
| class NodeRunData: | ||
| node_location: Loc | ||
| node: NodeDef | ||
| output_list: list[PortID] | ||
|
|
||
|
|
||
| def start_nodes( | ||
| @dataclass | ||
| class LoopIterTask: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually |
||
| iter_location: Loc | ||
| graph_input: OutputLoc | ||
| inputs: dict[PortID, OutputLoc] | ||
|
|
||
|
|
||
| Task = NodeRunData | LoopIterTask | ||
|
|
||
|
|
||
| def start_tasks( | ||
| storage: ControllerStorage, | ||
| executor: ControllerExecutor, | ||
| node_run_data: list[NodeRunData], | ||
| tasks: Sequence[Task], | ||
| enable_logging: bool = True, | ||
| ) -> None: | ||
| started_locs: set[Loc] = set() | ||
| for node_run_datum in node_run_data: | ||
| if node_run_datum.node_location in started_locs: | ||
| continue | ||
| start(storage, executor, node_run_datum) | ||
| started_locs.add(node_run_datum.node_location) | ||
| for task in tasks: | ||
| if isinstance(task, LoopIterTask): | ||
| start_graph( | ||
| storage, | ||
| executor, | ||
| task.iter_location, | ||
| task.graph_input, | ||
| task.inputs, | ||
| ) | ||
| started_locs.add(task.iter_location) | ||
| elif task.node_location not in started_locs: | ||
| start(storage, executor, task, enable_logging) | ||
| started_locs.add(task.node_location) | ||
|
|
||
|
|
||
| def run_builtin(def_path: Path, logs_path: Path) -> None: | ||
|
|
@@ -92,9 +113,7 @@ def start( | |
| executor.run(launcher_name, call_args_path) | ||
|
|
||
| elif node.type == "input": | ||
| input_loc = parent.N(-1) | ||
| storage.link_outputs(node_location, node.name, input_loc, node.name) | ||
| storage.mark_node_finished(node_location) | ||
| assert storage.is_node_finished(node_location) | ||
|
|
||
| elif node.type == "output": | ||
| storage.mark_node_finished(node_location) | ||
|
|
@@ -108,51 +127,31 @@ def start( | |
| storage.mark_node_finished(node_location) | ||
|
|
||
| elif node.type == "eval": | ||
| message = storage.read_output(parent.N(node.graph[0]), node.graph[1]) | ||
| g = ptype_from_bytes(message, GraphData) | ||
| ins["body"] = (parent.N(node.graph[0]), node.graph[1]) | ||
| ins.update(g.fixed_inputs) | ||
|
|
||
| pipe_inputs_to_output_location(storage, node_location.N(-1), ins) | ||
| graph_input = (parent.N(node.graph[0]), node.graph[1]) | ||
| start_graph(storage, executor, node_location, graph_input, ins) | ||
|
|
||
| elif node.type == "loop": | ||
| ins["body"] = (parent.N(node.body[0]), node.body[1]) | ||
| pipe_inputs_to_output_location(storage, node_location.N(-1), ins) | ||
| graph_input = (parent.N(node.body[0]), node.body[1]) | ||
| if ( | ||
| node.name is not None | ||
| ): # should we do this only in debug mode? -> need to think through how this would work | ||
| storage.write_debug_data(node.name, node_location) | ||
| start( | ||
| storage, | ||
| executor, | ||
| NodeRunData( | ||
| node_location.L(0), | ||
| Eval((-1, "body"), {k: (-1, k) for k, _ in ins.items()}, node.outputs), | ||
| output_list, | ||
| ), | ||
| ) | ||
| start_graph(storage, executor, node_location.L(0), graph_input, ins) | ||
|
|
||
| elif node.type == "map": | ||
| first_ref = next(x for x in ins.values() if x[1] == "*") | ||
| map_eles = outputs_iter(storage, first_ref[0]) | ||
| if not map_eles: | ||
| storage.mark_node_finished(node_location) | ||
| graph_input = (parent.N(node.body[0]), node.body[1]) | ||
| for idx, p in map_eles: | ||
| eval_inputs: dict[PortID, tuple[Loc, PortID]] = {} | ||
| eval_inputs["body"] = (parent.N(node.body[0]), node.body[1]) | ||
| for k, (i, port) in ins.items(): | ||
| if port == "*": | ||
| eval_inputs[k] = (i, p) | ||
| else: | ||
| eval_inputs[k] = (i, port) | ||
| pipe_inputs_to_output_location( | ||
| storage, node_location.M(idx).N(-1), eval_inputs | ||
| ) | ||
| # Necessary in the node visualization | ||
| storage.write_node_def( | ||
| node_location.M(idx), Eval((-1, "body"), node.inputs, node.outputs) | ||
| start_graph( | ||
| storage, | ||
| executor, | ||
| node_location.M(idx), | ||
| graph_input, | ||
| {k: (i, p if port == "*" else port) for k, (i, port) in ins.items()}, | ||
| ) | ||
|
|
||
| elif node.type == "ifelse": | ||
| pass | ||
|
|
||
|
|
@@ -162,6 +161,32 @@ def start( | |
| assert_never(node) | ||
|
|
||
|
|
||
| def start_graph( | ||
| storage: ControllerStorage, | ||
| executor: ControllerExecutor, | ||
| loc: Loc, | ||
| graph_input: OutputLoc, | ||
| ins: dict[PortID, OutputLoc], | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could add an |
||
| ) -> None: | ||
| # We have to write something here to mark the node/graph as started. | ||
| # TODO ALAN - pass NodeDef? But can't really make valid/correct inputs. | ||
|
||
| # For now just write a dummy, but don't overwrite if there's a better one already! | ||
| if not storage.is_node_started(loc): | ||
| storage.write_node_def(loc, Eval((-1, "body"), {})) | ||
| message = storage.read_output(*graph_input) | ||
| g = ptype_from_bytes(message, GraphData) | ||
| ins["body"] = graph_input | ||
| ins.update(g.fixed_inputs) | ||
| for i, n in enumerate(g.nodes): | ||
| if n.type == "input": | ||
| input_loc = loc.N(i) | ||
| if value := ins.get(n.name): | ||
| storage.link_outputs(input_loc, n.name, *value) | ||
| # else, ideally we'd check if that input is optional and error if not, | ||
| # but since we don't have the graph type here, we'll assume it's optional! | ||
| storage.mark_node_finished(input_loc) | ||
|
|
||
|
|
||
| def pipe_inputs_to_output_location( | ||
| storage: ControllerStorage, | ||
| output_loc: Loc, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.