2020import torch
2121from torch .export import ExportedProgram
2222
23- from tico .serialize .quant_param import QPARAM_KEY , QuantParam
2423from tico .utils import logging
2524from tico .utils .passes import PassBase , PassResult
2625from tico .utils .trace_decorators import trace_graph_diff_on_pass
27- from tico .utils .utils import get_quant_dtype
2826from tico .utils .validate_args_kwargs import CondArgs
2927from tico .utils .graph import create_node
30- from tico .utils .subgraph import get_gm_map
28+ from tico .utils .subgraph import get_all_graph_modules , freeze_subgraphs
3129import operator
30+ from torch .utils import _pytree as pytree
3231
3332@trace_graph_diff_on_pass
3433class MapSubgraph (PassBase ):
@@ -53,27 +52,50 @@ def call(self, exported_program: ExportedProgram, _) -> PassResult:
5352 continue
5453
5554 cond_args = CondArgs (* node .args , ** node .kwargs )
55+ true_graph = cond_args .true_graph
56+ false_graph = cond_args .false_graph
57+ graph_args = cond_args .cond_args
5658
57- true_graph_idx = None
58- false_graph_idx = None
59- for gm_info in get_gm_map (exported_program ):
60- if gm_info ["name" ] == cond_args .true_graph .name :
61- true_graph_idx = gm_info ["index" ]
62- continue
63- if gm_info ["name" ] == cond_args .false_graph .name :
64- false_graph_idx = gm_info ["index" ]
65- continue
66- assert true_graph_idx is not None
67- assert false_graph_idx is not None
59+ def _set_meta_val (graph_node , graph_module , graph_args ):
60+ def _get_meta_val (node ):
61+ assert hasattr (node , 'meta' ), f"'node' has no attribute named 'meta' (node: { node } )"
62+ assert "val" in node .meta , f"val key not in node.meta (node: { node } , meta: { node .meta } )"
63+ return node .meta ["val" ]
64+
65+ args , kwargs = pytree .tree_map_only (
66+ torch .fx .Node ,
67+ _get_meta_val ,
68+ (graph_args , {}),
69+ )
70+
71+ new_val = graph_module (* args , ** kwargs ) # type: ignore[operator]
72+ graph_node .meta ["val" ] = new_val
73+
74+ for graph_module , name in get_all_graph_modules (exported_program , subgraph_only = True ):
75+ if true_graph .name == name :
76+ _set_meta_val (true_graph , graph_module , graph_args )
77+ if false_graph .name == name :
78+ _set_meta_val (false_graph , graph_module , graph_args )
79+
80+ assert "val" in true_graph .meta , f"{ true_graph } has no node.meta['val']"
81+ assert "val" in false_graph .meta , f"{ false_graph } has no node.meta['val']"
6882
83+ freeze_subgraphs (exported_program )
6984 with graph .inserting_before (node ):
7085 circle_if = create_node (
7186 graph ,
7287 torch .ops .circle_custom .if_ ,
73- args = (cond_args .condition , true_graph_idx , false_graph_idx , cond_args .cond_args ),
88+ args = (cond_args .condition , cond_args . true_graph , cond_args . false_graph , cond_args .cond_args ),
7489 kwargs = {},
7590 origin = node ,
7691 )
92+
93+ for t , f in zip (true_graph .meta ['val' ], false_graph .meta ['val' ]):
94+ assert type (t ) == type (f )
95+ assert t .shape == f .shape , f"{ t .shape } != { f .shape } "
96+ assert t .dtype == f .dtype , f"{ t .dtype } != { f .dtype } "
97+
98+ circle_if .meta ["val" ] = true_graph .meta ['val' ][0 ]
7799
78100 # FIX ME UNLESS torch.ops.higher_order.cond generates this pattern
79101 assert len (node .users ) == 1
0 commit comments