Skip to content

Commit 496cc5d

Browse files
fix aten IR issue with Conv2ds, including for ViTs
Signed-off-by: cliu-us <[email protected]>
1 parent 5315f13 commit 496cc5d

File tree

2 files changed

+51
-22
lines changed

2 files changed

+51
-22
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,27 @@
3333
logger = logging.getLogger(__name__)
3434

3535

36+
def run_fwd_once(model, sample_inp):
    """Run a single forward pass of ``model`` on ``sample_inp`` with gradients disabled.

    How ``sample_inp`` is handed to ``model`` depends on its type:

    - dict, or any object exposing ``keys``/``values``/``items``: unpacked as keyword args
    - tuple: unpacked as positional args
    - torch.Tensor: passed as the single positional arg
    - anything else: passed as-is, assuming the user-provided input is ready-to-run

    Args:
        model: callable to invoke, e.g. a torch.nn.Module or a torch.compile'd module.
        sample_inp: example input for the model.

    Returns:
        Whatever ``model`` returns, or None if ``sample_inp`` is of an unrecognized type and
        the best-effort call raised a RuntimeError.
    """
    # Pre-bind the result: the fallback branch below swallows RuntimeError on purpose, and
    # without this the final `return out` would raise UnboundLocalError on that path.
    out = None

    # Duck-typed mapping check so dict-like containers that are not dict subclasses
    # (anything providing keys/values/items) are also unpacked as kwargs.
    is_mapping_like = isinstance(sample_inp, dict) or all(
        hasattr(sample_inp, k) for k in ("keys", "values", "items")
    )

    with torch.no_grad():
        if is_mapping_like:
            out = model(**sample_inp)
        elif isinstance(sample_inp, tuple):
            out = model(*sample_inp)
        elif isinstance(sample_inp, torch.Tensor):
            out = model(sample_inp)
        else:
            try:
                # assume user provided input is ready-to-run...
                out = model(sample_inp)
            except RuntimeError:
                # Best effort only: report the problem and fall through to return None.
                logger.warning(
                    f"Unknown data structure for example_input.{type(sample_inp)} Please check."
                )
    return out
55+
56+
3657
def dfs_gm(
3758
gm,
3859
targetOp=None,
@@ -229,7 +250,9 @@ def _dfs(curr_node, depth):
229250

230251

231252
def find_conv_on_shortcut_gm(
232-
gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
253+
gm: torch.fx.GraphModule,
254+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
255+
lut_name_to_mod={},
233256
):
234257
"""Identify Conv on shortcut using FX GM DFS
235258
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -337,9 +360,13 @@ def find_conv_on_shortcut_gm(
337360
if n_conv_i.op == "call_module":
338361
conv_mod = gm.get_submodule(n_conv_i.target)
339362
else:
340-
conv_mod = get_org_mod_name_of_fx_node(
363+
# in case aten IR is being used
364+
conv_mod_name = get_org_mod_name_of_fx_node(
341365
n_conv_i, lut_fx2org=lut_fx_mod_name_to_org
342366
)
367+
conv_mod = lut_name_to_mod.get(conv_mod_name, None)
368+
if not isinstance(conv_mod, torch.nn.Conv2d):
369+
continue
343370
if conv_mod.out_channels > conv_mod.in_channels: # see Note 2
344371
qconv_candidate.append(
345372
get_org_mod_name_of_fx_node(
@@ -1003,8 +1030,17 @@ def cus_backend_model_analyzer(
10031030
for _, m in gm_fx.named_modules()
10041031
if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
10051032
]
1006-
if len(all_conv) > 0:
1007-
skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
1033+
# if gm is using aten IR, only ops can be seen, no modules.
1034+
conv_ops = dfs_gm(
1035+
gm_fx,
1036+
targetOp=[torch.nn.Conv2d, torch.nn.functional.conv2d],
1037+
return_nodes=True,
1038+
)
1039+
lut_name_to_mod = {n: m for m, n in qcfg["LUTmodule_name"].items()}
1040+
if len(all_conv) > 0 or len(conv_ops) > 0:
1041+
skip_candidates += find_conv_on_shortcut_gm(
1042+
gm_fx, lut_fx_mod_name_to_org, lut_name_to_mod
1043+
)
10081044

10091045
# Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
10101046
if qcfg["N_backend_called"] > 1:
@@ -1064,6 +1100,7 @@ def cus_backend_model_analyzer(
10641100
from functools import partial
10651101

10661102
# Third Party
1103+
from torchvision.models import VisionTransformer
10671104
from transformers import PreTrainedModel
10681105

10691106
if issubclass(type(model), torch.nn.Module):
@@ -1075,7 +1112,7 @@ def cus_backend_model_analyzer(
10751112
model_to_be_traced = model
10761113
model_param_size = 999
10771114

1078-
is_transformers = issubclass(type(model), PreTrainedModel)
1115+
is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
10791116
if model_param_size > 1:
10801117
# Standard
10811118
import sys
@@ -1111,35 +1148,25 @@ def call_seq_hook(mod, *_args, **_kwargs):
11111148
h_hooks.append(m.register_forward_hook(call_seq_hook))
11121149

11131150
with torch.no_grad():
1114-
model(**sample_inp)
1151+
run_fwd_once(model, sample_inp)
11151152

11161153
for h in h_hooks:
11171154
h.remove()
11181155

11191156
# only add last layer
11201157
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
1158+
# unless it's a ViT, skip first Conv as well
1159+
if issubclass(type(model), VisionTransformer) and isinstance(
1160+
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
1161+
):
1162+
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
11211163

11221164
with torch.no_grad():
11231165
model_opt = torch.compile(
11241166
model_to_be_traced,
11251167
backend=cus_bknd,
11261168
)
1127-
if isinstance(sample_inp, dict) or all(
1128-
hasattr(sample_inp, k) for k in ("keys", "values", "items")
1129-
):
1130-
model_opt(**sample_inp)
1131-
elif isinstance(sample_inp, tuple):
1132-
model_opt(*sample_inp)
1133-
elif isinstance(sample_inp, torch.Tensor):
1134-
model_opt(sample_inp)
1135-
else:
1136-
try:
1137-
# assume user provided input is ready-to-run...
1138-
model_opt(sample_inp)
1139-
except RuntimeError:
1140-
logger.info(
1141-
f"Unknown data structure for example_input.{type(sample_inp)} Please check."
1142-
)
1169+
run_fwd_once(model_opt, sample_inp)
11431170

11441171
del model_opt
11451172

fms_mo/fx/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ def get_org_mod_name_of_fx_node(
343343
str: corresponding name on original graph
344344
"""
345345
org_name = f"Unknown:{node.name}"
346+
if lut_fx2org == None:
347+
lut_fx2org = {}
346348
if "nn_module_stack" in node.meta:
347349
n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]
348350
n_fx_org_mod_name = list(node.meta["nn_module_stack"].values())[-1][0]

0 commit comments

Comments
 (0)