Skip to content

Commit 3ed898a

Browse files
committed
Pass: all op.cond tests
1 parent 0dea901 commit 3ed898a

File tree

9 files changed

+112
-112
lines changed

9 files changed

+112
-112
lines changed

tico/experimental/controlflow/passes/map_subgraph.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@
2020
import torch
2121
from torch.export import ExportedProgram
2222

23-
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
2423
from tico.utils import logging
2524
from tico.utils.passes import PassBase, PassResult
2625
from tico.utils.trace_decorators import trace_graph_diff_on_pass
27-
from tico.utils.utils import get_quant_dtype
2826
from tico.utils.validate_args_kwargs import CondArgs
2927
from tico.utils.graph import create_node
30-
from tico.utils.subgraph import get_gm_map
28+
from tico.utils.subgraph import get_all_graph_modules, freeze_subgraphs
3129
import operator
30+
from torch.utils import _pytree as pytree
3231

3332
@trace_graph_diff_on_pass
3433
class MapSubgraph(PassBase):
@@ -53,27 +52,50 @@ def call(self, exported_program: ExportedProgram, _) -> PassResult:
5352
continue
5453

5554
cond_args = CondArgs(*node.args, **node.kwargs)
55+
true_graph = cond_args.true_graph
56+
false_graph = cond_args.false_graph
57+
graph_args = cond_args.cond_args
5658

57-
true_graph_idx = None
58-
false_graph_idx = None
59-
for gm_info in get_gm_map(exported_program):
60-
if gm_info["name"] == cond_args.true_graph.name:
61-
true_graph_idx = gm_info["index"]
62-
continue
63-
if gm_info["name"] == cond_args.false_graph.name:
64-
false_graph_idx = gm_info["index"]
65-
continue
66-
assert true_graph_idx is not None
67-
assert false_graph_idx is not None
59+
def _set_meta_val(graph_node, graph_module, graph_args):
60+
def _get_meta_val(node):
61+
assert hasattr(node, 'meta'), f"'node' has no attribute named 'meta' (node: {node})"
62+
assert "val" in node.meta, f"val key not in node.meta (node: {node}, meta: {node.meta})"
63+
return node.meta["val"]
64+
65+
args, kwargs = pytree.tree_map_only(
66+
torch.fx.Node,
67+
_get_meta_val,
68+
(graph_args, {}),
69+
)
70+
71+
new_val = graph_module(*args, **kwargs) # type: ignore[operator]
72+
graph_node.meta["val"] = new_val
73+
74+
for graph_module, name in get_all_graph_modules(exported_program, subgraph_only=True):
75+
if true_graph.name == name:
76+
_set_meta_val(true_graph, graph_module, graph_args)
77+
if false_graph.name == name:
78+
_set_meta_val(false_graph, graph_module, graph_args)
79+
80+
assert "val" in true_graph.meta, f"{true_graph} has no node.meta['val']"
81+
assert "val" in false_graph.meta, f"{false_graph} has no node.meta['val']"
6882

83+
freeze_subgraphs(exported_program)
6984
with graph.inserting_before(node):
7085
circle_if = create_node(
7186
graph,
7287
torch.ops.circle_custom.if_,
73-
args=(cond_args.condition, true_graph_idx, false_graph_idx, cond_args.cond_args),
88+
args=(cond_args.condition, cond_args.true_graph, cond_args.false_graph, cond_args.cond_args),
7489
kwargs={},
7590
origin=node,
7691
)
92+
93+
for t, f in zip(true_graph.meta['val'], false_graph.meta['val']):
94+
assert type(t) == type(f)
95+
assert t.shape == f.shape, f"{t.shape} != {f.shape}"
96+
assert t.dtype == f.dtype, f"{t.dtype} != {f.dtype}"
97+
98+
circle_if.meta["val"] = true_graph.meta['val'][0]
7799

78100
# FIX ME UNLESS torch.ops.higher_order.cond generates this pattern
79101
assert len(node.users) == 1

tico/serialize/circle_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,5 +324,4 @@ def get_tid(
324324
return self.name_to_tid[node_name]
325325

326326
# Unreachable
327-
breakpoint()
328327
raise RuntimeError("fx Node was not converted to tensor.")

tico/serialize/circle_serializer.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@
2828
from tico.serialize.operators.node_visitor import get_node_visitors
2929
from tico.utils import logging
3030
from tico.utils.serialize import finalise_tensor_names, validate_tensor_shapes
31+
from tico.utils.subgraph import get_all_graph_modules
3132

3233

3334
multiple_output_ops = [
3435
torch.ops.aten.split_with_sizes.default,
3536
torch.ops.aten.max.dim,
37+
# torch.ops.circle_custom.if_.default,
38+
# torch.ops.circle_custom.if_,
3639
]
3740

3841
def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
@@ -46,7 +49,6 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
4649
graph = CircleSubgraph(model)
4750
return model, graph
4851

49-
from tico.utils.subgraph import get_gm_map
5052

5153
def build_circle(
5254
ep: ExportedProgram, config: CompileConfigBase = get_default_config()
@@ -64,21 +66,17 @@ def build_circle(
6466
model = CircleModel()
6567

6668
op_codes: Dict[OpCode, int] = {}
67-
68-
for gm_info in get_gm_map(ep):
69-
if gm_info["name"]: #non-root subgraph
70-
graph_module = getattr(ep.graph_module, gm_info["name"])
71-
else:
72-
graph_module = ep.graph_module
69+
for graph_module, name in get_all_graph_modules(ep):
7370
ep_graph = graph_module.graph
74-
7571
graph = CircleSubgraph(model)
72+
7673
# Export tensors
77-
if gm_info["name"]: #non-root subgraph
78-
_export_tensors_for_subgraph(graph, ep_graph, ep)
79-
else:
74+
if name == '': # root graph
8075
_export_tensors(graph, ep_graph, ep)
81-
if gm_info["index"] == 0: # Root graph
76+
else:
77+
_export_tensors_for_subgraph(graph, ep_graph, ep)
78+
79+
if name == '': # root graph
8280
# Register inputs
8381
logger.debug("---------------Register inputs--------------")
8482
for in_spec in ep.graph_signature.input_specs:
@@ -108,7 +106,6 @@ def build_circle(
108106
# Export operators
109107
logger.debug("---------------Export operators--------------")
110108
visitors = get_node_visitors(op_codes, graph)
111-
ep_graph.print_tabular()
112109
for node in ep_graph.nodes:
113110
if node.op != "call_function":
114111
continue
@@ -161,8 +158,6 @@ def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> Non
161158
if node.target in multiple_output_ops:
162159
continue
163160
node_val = node.meta["val"]
164-
if node.name == 'cond':
165-
continue
166161
if node_val.layout != torch.strided:
167162
raise RuntimeError(
168163
f"Only support dense tensors (node layout: {node_val.layout})"
@@ -179,7 +174,7 @@ def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> Non
179174
elif node.op == "output":
180175
for output in node.args[0]:
181176
if isinstance(output, torch.fx.Node):
182-
assert graph.has_tensor(output.name)
177+
assert graph.has_tensor(output.name), f"{output}"
183178
continue
184179

185180
elif node.op == "call_method":

tico/serialize/operators/op_circle_if.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
2727
from tico.utils.validate_args_kwargs import CircleIfArgs
2828
from tico.utils.errors import NotYetSupportedError
29-
29+
from tico.utils.subgraph import get_frozen_subgraphs
3030

3131
@register_node_visitor
3232
class CircleIfVisitor(NodeVisitor):
33-
target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_]
33+
target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_, torch.ops.circle_custom.if_.default]
3434

3535
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
3636
super().__init__(op_codes, graph)
@@ -45,17 +45,28 @@ def define_node(
4545
if_args = CircleIfArgs(*node.args, **node.kwargs)
4646

4747
pred = if_args.pred
48-
then_idx = if_args.then_graph_idx
49-
else_idx = if_args.else_graph_idx
48+
then_graph = if_args.then_graph
49+
else_graph = if_args.else_graph
5050
arguments = if_args.if_args
5151

52+
then_graph_idx = None
53+
else_graph_idx = None
54+
for frozen_subgraph in get_frozen_subgraphs():
55+
if frozen_subgraph.name == then_graph.name:
56+
then_graph_idx = frozen_subgraph.idx
57+
if frozen_subgraph.name == else_graph.name:
58+
else_graph_idx = frozen_subgraph.idx
59+
assert then_graph_idx is not None
60+
assert else_graph_idx is not None
61+
62+
5263
inputs = [pred, *arguments]
5364
outputs = [node]
54-
65+
# outputs = [i for i in node.users.keys()]
5566
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
5667
operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.IfOptions
5768
operator.builtinOptions = circle.IfOptions.IfOptionsT()
58-
operator.builtinOptions.thenSubgraphIndex = then_idx
59-
operator.builtinOptions.elseSubgraphIndex = else_idx
69+
operator.builtinOptions.thenSubgraphIndex = then_graph_idx
70+
operator.builtinOptions.elseSubgraphIndex = else_graph_idx
6071

6172
return operator

tico/utils/convert.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
trace_graph_diff_on_func,
8888
)
8989
from tico.utils.utils import has_quantization_ops, SuppressWarning
90-
from tico.utils.subgraph import get_gm_map
90+
from tico.utils.subgraph import get_all_graph_modules
9191

9292

9393
@trace_const_diff_on_func
@@ -199,11 +199,7 @@ def convert_exported_module_to_circle(
199199

200200
assert isinstance(config, CompileConfigBase)
201201

202-
for gm_info in get_gm_map(exported_program):
203-
if gm_info["name"]: #non-root subgraph
204-
graph_module = getattr(exported_program.graph_module, gm_info["name"])
205-
else:
206-
graph_module = exported_program.graph_module
202+
for graph_module, _ in get_all_graph_modules(exported_program):
207203
logger = logging.getLogger(__name__)
208204
logger.debug("Input ExportedProgram (must be core aten)")
209205
logger.debug(exported_program)
@@ -232,13 +228,9 @@ def convert_exported_module_to_circle(
232228
# CompositeImplicitAutograd and have functional schema are safe to not decompose.
233229
exported_program = traced_run_decompositions(exported_program)
234230

235-
for gm_info in get_gm_map(exported_program):
236-
if gm_info["name"]: #non-root subgraph
237-
graph_module = getattr(exported_program.graph_module, gm_info["name"])
238-
else:
239-
graph_module = exported_program.graph_module
231+
for graph_module, _ in get_all_graph_modules(exported_program):
240232
graph = graph_module.graph
241-
233+
242234
reinterpret_pass = PassManager(
243235
passes=[
244236
MapSubgraph(),

tico/utils/register_custom_op.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,10 @@
1919
from torch.library import custom_op, register_fake
2020

2121
from tico.utils.mx.mx_ops import _quantize_mx
22-
from tico.utils.subgraph import get_gm_map
2322

24-
# Note that an operator assumes input tensor has NHWC format.
2523
def CircleIf():
2624
@custom_op("circle_custom::if_", mutates_args=())
27-
def if_(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List[torch.Tensor]) -> torch.Tensor:
28-
true_graph = None
29-
false_graph = None
30-
for gm_info in get_gm_map():
31-
if gm_info["index"] == true_graph_idx:
32-
true_graph = gm_info["gm"]
33-
continue
34-
if gm_info["index"] == false_graph_idx:
35-
false_graph = gm_info["gm"]
36-
continue
37-
25+
def if_(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]) -> torch.Tensor:
3826
if pred:
3927
result = true_graph(*if_args)
4028
assert len(result) == 1 # TODO: Support tuple of result
@@ -45,23 +33,12 @@ def if_(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args:
4533
return result[0]
4634

4735
@register_fake("circle_custom::if_")
48-
def _(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List):
49-
true_graph = None
50-
false_graph = None
51-
for gm_info in get_gm_map():
52-
if gm_info["index"] == true_graph_idx:
53-
true_graph = gm_info["gm"]
54-
continue
55-
if gm_info["index"] == false_graph_idx:
56-
false_graph = gm_info["gm"]
57-
continue
58-
36+
def _(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]):
5937
result = true_graph(*if_args)
6038
assert len(result) == 1 # TODO: Support tuple of result
6139

6240
return result[0]
6341

64-
6542
# Note that an operator assumes input tensor has NHWC format.
6643
def CircleResizeNearestNeighbor():
6744
@custom_op("circle_custom::resize_nearest_neighbor", mutates_args=())

tico/utils/subgraph.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,42 @@
11
import torch
22
from torch.export import ExportedProgram
3-
from typing import Optional
4-
import functools
3+
from copy import deepcopy
4+
from typing import Iterator, List, Iterator
5+
from dataclasses import dataclass
6+
@dataclass
7+
class FrozenSubgraph:
8+
idx: int
9+
name: str # model-wise, unique name
10+
frozen_graph_module: torch.fx.GraphModule # copied subgraph
511

6-
_gm_map = None
7-
def get_gm_map(ep: Optional[ExportedProgram] = None):
12+
_frozen_subgraphs: List[FrozenSubgraph] = []
13+
14+
def freeze_subgraphs(ep: ExportedProgram):
815
"""
9-
Returns [{"index":0, "name": "true_graph_0", "getter": lambda ep: ep.graph_module}, ...}]
16+
Freeze subgraphs to provide shape inference logic of FakeTensor.
1017
"""
11-
# Build _gm_map only once while compiler running
12-
global _gm_map
13-
if _gm_map is None:
14-
assert ep is not None
15-
_gm_map = _build_gm_map(ep)
16-
return _gm_map
18+
for idx, (graph_module, name) in enumerate(get_all_graph_modules(ep, subgraph_only=True), start = 1):
19+
global _frozen_subgraphs
20+
_frozen_subgraphs += [FrozenSubgraph(idx = idx, name = name, frozen_graph_module = deepcopy(graph_module))]
1721

18-
def _build_gm_map(ep: ExportedProgram):
19-
ret = []
20-
21-
# root GraphModule 추가
22-
ret.append({
23-
"index": len(ret),
24-
"name": "",
25-
"gm": ep.graph_module,
26-
})
22+
def get_frozen_subgraphs() -> List[FrozenSubgraph]:
23+
global _frozen_subgraphs
24+
return _frozen_subgraphs
25+
26+
27+
def get_all_graph_modules(ep: ExportedProgram, subgraph_only: bool = False) -> Iterator[tuple[torch.fx.GraphModule, str]]:
28+
"""
29+
Get all graph modules and its name
30+
"""
31+
if not subgraph_only:
32+
yield ep.graph_module, "" # root has no name
2733

28-
# Inspect non-root subgraphs
34+
# yield subgraphs
2935
for node in ep.graph.nodes:
3036
if node.op == "get_attr":
31-
attr = getattr(node.graph.owning_module, node.target)
37+
graph_module = getattr(node.graph.owning_module, node.target)
3238

3339
# TODO: Enable recursion (n-depth)
34-
if isinstance(attr, torch.fx.graph_module.GraphModule):
35-
assert hasattr(node, 'name')
36-
assert getattr(node, 'name') != ret[0]["name"]
37-
graph_name = getattr(node, 'name')
38-
ret.append({
39-
"index": len(ret),
40-
"name": graph_name,
41-
"gm": attr,
42-
})
43-
return ret
40+
if isinstance(graph_module, torch.fx.graph_module.GraphModule):
41+
assert hasattr(graph_module, 'meta')
42+
yield graph_module, getattr(node, 'name')

0 commit comments

Comments (0)