Skip to content

Commit 3592cb5

Browse files
mlazos authored and pytorchmergebot committed
[Hierarchical Compilation] Use universal flatten APIs (pytorch#152505)
Pull Request resolved: pytorch#152505 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#152389
1 parent 023a3dc commit 3592cb5

File tree

5 files changed

+40
-52
lines changed

5 files changed

+40
-52
lines changed

test/dynamo/test_graph_deduplication.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# flake8: noqa: B950
33
import torch
44
import torch.fx
5-
from torch._dynamo.graph_deduplication import _flatten_args_kwargs
65
from torch._dynamo.graph_utils import _detect_cycles
76
from torch._dynamo.test_case import TestCase
87
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
@@ -583,13 +582,6 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
583582
""",
584583
)
585584

586-
def test_flatten_with_slices(self):
587-
tree = [{"x": 3}, ["x", slice(1, 2, 3), 1], [4, 5, 6, [slice(3, 4, 5)]]]
588-
out = _flatten_args_kwargs(tree)
589-
self.assertExpectedInline(
590-
str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]"""
591-
)
592-
593585
def test_cycle_detection_no_cycle(self):
594586
def fn(x, y):
595587
x0 = x + 1

torch/_dynamo/graph_deduplication.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import operator
1212
from collections import defaultdict
1313
from collections.abc import Generator, Iterable
14-
from typing import Any, Optional
14+
from typing import Any
1515

1616
import torch
1717
import torch.fx
@@ -20,7 +20,7 @@
2020
from torch.utils._ordered_set import OrderedSet
2121

2222
from .graph_region_tracker import Node, Region
23-
from .graph_utils import _detect_cycles, _flatten_args_kwargs
23+
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
2424

2525

2626
log = logging.getLogger(__name__)
@@ -92,7 +92,10 @@ def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]:
9292
node_to_additional_deps,
9393
)
9494

95-
_stable_topological_sort(output_graph.graph, node_to_additional_deps)
95+
_stable_topological_sort(
96+
output_graph.graph,
97+
node_to_additional_deps, # type: ignore[arg-type]
98+
)
9699
return sub_gms
97100

98101

@@ -109,7 +112,7 @@ def _replace_region_with_subgraph(
109112
sub_args = []
110113
for node_ind, arg_ind in node_ind_arg_ind:
111114
node = region[node_ind]
112-
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
115+
flattened_args_kwargs = _get_flat_args(node, {})
113116
sub_args.append(flattened_args_kwargs[arg_ind])
114117

115118
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
@@ -162,7 +165,7 @@ def _get_external_inputs(
162165
external_node_to_indices = dict()
163166
region_unique = set(region)
164167
for node_ind, node in enumerate(region):
165-
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
168+
flattened_args_kwargs = _get_flat_args(node, {})
166169
for arg_ind, in_node in enumerate(flattened_args_kwargs):
167170
if (
168171
isinstance(in_node, Node)
@@ -237,23 +240,9 @@ def _create_subgraph(
237240
return subgraph, node_ind_input_inds
238241

239242

240-
def _args(
241-
n: torch.fx.Node,
242-
node_to_additional_deps: Optional[dict[torch.fx.Node, list[torch.fx.Node]]] = None,
243-
) -> list[torch.fx.node.Argument]:
244-
if node_to_additional_deps is None:
245-
node_to_additional_deps = {}
246-
247-
args: list[torch.fx.node.Argument] = []
248-
torch.fx.map_arg((n.args, n.kwargs), args.append)
249-
if n in node_to_additional_deps:
250-
args.extend(node_to_additional_deps[n])
251-
return args
252-
253-
254243
def _stable_topological_sort(
255244
graph: torch.fx.Graph,
256-
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
245+
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
257246
) -> None:
258247
# Nodes are in exactly one of these four collections:
259248

@@ -283,7 +272,9 @@ def _stable_topological_sort(
283272
continue
284273

285274
waiting_for = [
286-
x for x in _args(node, node_to_additional_deps) if x not in ready
275+
x
276+
for x in _get_flat_args_unique(node, node_to_additional_deps)
277+
if x not in ready
287278
]
288279
if waiting_for:
289280
# We have unprocessed input nodes. Might as well wait for the last
@@ -328,7 +319,7 @@ def prev_cur_nodes(
328319
prev_nodes.append(cur_node)
329320

330321
for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
331-
args_unique = _args(cur_node)
322+
args_unique = _get_flat_args_unique(cur_node, {})
332323
additional_deps = node_to_additional_deps[cur_node]
333324
additional_deps.extend(n for n in all_nodes_dep_on if n not in args_unique)
334325
if cur_node.target in global_state_targets:

torch/_dynamo/graph_region_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from torch.utils._ordered_set import OrderedSet
2929
from torch.utils._pytree import tree_flatten
3030

31-
from .graph_utils import _flatten_args_kwargs
31+
from .graph_utils import _get_flat_args_unique
3232

3333

3434
T = TypeVar("T")
@@ -416,7 +416,7 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
416416
for node in graph.nodes:
417417
node_to_recursive_ancestors[node] = set()
418418
for node in graph.nodes:
419-
all_args = _flatten_args_kwargs((node.args, node.kwargs))
419+
all_args = _get_flat_args_unique(node, {})
420420
for arg in all_args:
421421
if isinstance(arg, Node):
422422
node_to_recursive_ancestors[node].update(

torch/_dynamo/graph_utils.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
11
from collections import deque
22
from typing import Any
33

4-
from torch.fx import Graph, Node
5-
from torch.utils._pytree import tree_flatten
4+
from torch.fx import Graph, map_arg, Node
5+
from torch.utils._ordered_set import OrderedSet
66

77

88
# flattens with support for slices
99
# Note: a better way to do this would
1010
# be register/unregister slices as pytree nodes
1111
# but there is no unregister API in the pytorch
1212
# pytree impl
13-
def _flatten_args_kwargs(args: Any) -> list[Node]:
14-
fully_flattened = []
15-
16-
def flatten(args: Any) -> None:
17-
flattened, _ = tree_flatten(args)
18-
for arg in flattened:
19-
if isinstance(arg, slice):
20-
start = arg.start
21-
stop = arg.stop
22-
step = arg.step
23-
flatten((start, stop, step))
24-
else:
25-
fully_flattened.append(arg)
26-
27-
flatten(args)
28-
29-
return fully_flattened
13+
def _get_flat_args(
14+
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
15+
) -> list[Node]:
16+
args = list[Any]()
17+
map_arg((node.args, node.kwargs), args.append)
18+
if node in node_to_additional_deps:
19+
args.extend(node_to_additional_deps[node])
20+
return args
21+
22+
23+
def _get_flat_args_unique(
24+
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
25+
) -> OrderedSet[Node]:
26+
args = OrderedSet[Node]()
27+
map_arg((node.args, node.kwargs), args.add)
28+
if node in node_to_additional_deps:
29+
args.update(node_to_additional_deps[node])
30+
return args
3031

3132

3233
def _detect_cycles(graph: Graph) -> str:

torch/_dynamo/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
from torch.utils._triton import has_triton, has_triton_package
9393
from torch.utils.hooks import RemovableHandle
9494

95+
from .graph_utils import _get_flat_args
96+
9597

9698
if typing.TYPE_CHECKING:
9799
from collections.abc import (
@@ -3150,7 +3152,9 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
31503152
args, kwargs = get_fake_values_from_nodes(
31513153
tx, (node.args, node.kwargs), allow_non_graph_fake
31523154
)
3153-
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
3155+
flat_args_kwargs = get_fake_values_from_nodes(
3156+
tx, _get_flat_args(node, {}), allow_non_graph_fake
3157+
)
31543158
id_to_initial_version = {
31553159
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
31563160
}

0 commit comments

Comments (0)