Commit e9bc8d6

Merge pull request #81 from chichun-charlie-liu/main
fix: a bug that prevented dynamo from working with PT 2.5.1 has been fixed
2 parents: b5f84d9 + fcb9618 · commit e9bc8d6

File tree: 3 files changed (+57, -24 lines)

fms_mo/fx/dynamo_utils.py

Lines changed: 53 additions & 22 deletions
@@ -33,6 +33,28 @@
 logger = logging.getLogger(__name__)
 
 
+def run_fwd_once(model, sample_inp):
+    """Convenient function to run model once using correct input unpack."""
+    with torch.no_grad():
+        if isinstance(sample_inp, dict) or all(
+            hasattr(sample_inp, k) for k in ("keys", "values", "items")
+        ):
+            out = model(**sample_inp)
+        elif isinstance(sample_inp, tuple):
+            out = model(*sample_inp)
+        elif isinstance(sample_inp, torch.Tensor):
+            out = model(sample_inp)
+        else:
+            try:
+                # assume user provided input is ready-to-run...
+                out = model(sample_inp)
+            except RuntimeError:
+                logger.info(
+                    f"Unknown data structure for example_input.{type(sample_inp)} Please check."
+                )
+    return out
+
+
 def dfs_gm(
     gm,
     targetOp=None,
@@ -229,7 +251,9 @@ def _dfs(curr_node, depth):
 
 
 def find_conv_on_shortcut_gm(
-    gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
+    gm: torch.fx.GraphModule,
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
+    lut_name_to_mod=None,
 ):
     """Identify Conv on shortcut using FX GM DFS
     It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -254,6 +278,9 @@ def find_conv_on_shortcut_gm(
        5. count levels of each branch, decide which one is the shortcut
     """
 
+    if lut_name_to_mod is None:
+        lut_name_to_mod = {}
+
     # 1. Find "add" nodes, including inplace add as some may use "out+=shortcut"
     nodes_add = dfs_gm(gm, ["add"], return_nodes=True)
 
@@ -337,9 +364,13 @@ def find_conv_on_shortcut_gm(
             if n_conv_i.op == "call_module":
                 conv_mod = gm.get_submodule(n_conv_i.target)
             else:
-                conv_mod = get_org_mod_name_of_fx_node(
+                # in case aten IR is being used
+                conv_mod_name = get_org_mod_name_of_fx_node(
                     n_conv_i, lut_fx2org=lut_fx_mod_name_to_org
                 )
+                conv_mod = lut_name_to_mod.get(conv_mod_name, None)
+            if not isinstance(conv_mod, torch.nn.Conv2d):
+                continue
             if conv_mod.out_channels > conv_mod.in_channels:  # see Note 2
                 qconv_candidate.append(
                     get_org_mod_name_of_fx_node(
@@ -1003,8 +1034,17 @@ def cus_backend_model_analyzer(
         for _, m in gm_fx.named_modules()
         if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
     ]
-    if len(all_conv) > 0:
-        skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
+    # if gm is using aten IR, only ops can be seen, no modules.
+    conv_ops = dfs_gm(
+        gm_fx,
+        targetOp=[torch.nn.Conv2d, torch.nn.functional.conv2d],
+        return_nodes=True,
+    )
+    lut_name_to_mod = {n: m for m, n in qcfg["LUTmodule_name"].items()}
+    if len(all_conv) > 0 or len(conv_ops) > 0:
+        skip_candidates += find_conv_on_shortcut_gm(
+            gm_fx, lut_fx_mod_name_to_org, lut_name_to_mod
+        )
 
     # Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
     if qcfg["N_backend_called"] > 1:
@@ -1064,6 +1104,7 @@ def cus_backend_model_analyzer(
     from functools import partial
 
     # Third Party
+    from torchvision.models import VisionTransformer
     from transformers import PreTrainedModel
 
     if issubclass(type(model), torch.nn.Module):
@@ -1075,7 +1116,7 @@ def cus_backend_model_analyzer(
         model_to_be_traced = model
         model_param_size = 999
 
-    is_transformers = issubclass(type(model), PreTrainedModel)
+    is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
     if model_param_size > 1:
         # Standard
         import sys
@@ -1111,35 +1152,25 @@ def call_seq_hook(mod, *_args, **_kwargs):
            h_hooks.append(m.register_forward_hook(call_seq_hook))
 
        with torch.no_grad():
-           model(**sample_inp)
+           run_fwd_once(model, sample_inp)
 
        for h in h_hooks:
            h.remove()
 
        # only add last layer
        qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
+       # unless it's a ViT, skip first Conv as well
+       if issubclass(type(model), VisionTransformer) and isinstance(
+           model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
+       ):
+           qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
 
    with torch.no_grad():
        model_opt = torch.compile(
            model_to_be_traced,
            backend=cus_bknd,
        )
-       if isinstance(sample_inp, dict) or all(
-           hasattr(sample_inp, k) for k in ("keys", "values", "items")
-       ):
-           model_opt(**sample_inp)
-       elif isinstance(sample_inp, tuple):
-           model_opt(*sample_inp)
-       elif isinstance(sample_inp, torch.Tensor):
-           model_opt(sample_inp)
-       else:
-           try:
-               # assume user provided input is ready-to-run...
-               model_opt(sample_inp)
-           except RuntimeError:
-               logger.info(
-                   f"Unknown data structure for example_input.{type(sample_inp)} Please check."
-               )
+       run_fwd_once(model_opt, sample_inp)
 
        del model_opt
 
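For reference, a minimal usage sketch of the new run_fwd_once helper (the toy nn.Linear model and its sample inputs below are illustrative assumptions, not code from this PR): a mapping is unpacked with **, a tuple with *, and a bare tensor is passed through directly, so the same call handles HF-style dict inputs as well as plain tensors.

    import torch
    import torch.nn as nn

    from fms_mo.fx.dynamo_utils import run_fwd_once  # helper added in this commit

    model = nn.Linear(8, 4)  # toy stand-in for the model under analysis

    out_tensor = run_fwd_once(model, torch.randn(2, 8))              # tensor  -> model(x)
    out_tuple = run_fwd_once(model, (torch.randn(2, 8),))            # tuple   -> model(*args)
    out_mapping = run_fwd_once(model, {"input": torch.randn(2, 8)})  # mapping -> model(**kwargs)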

fms_mo/fx/utils.py

Lines changed: 3 additions & 1 deletion
@@ -343,6 +343,8 @@ def get_org_mod_name_of_fx_node(
         str: corresponding name on original graph
     """
     org_name = f"Unknown:{node.name}"
+    if lut_fx2org is None:
+        lut_fx2org = {}
     if "nn_module_stack" in node.meta:
         n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]
         n_fx_org_mod_name = list(node.meta["nn_module_stack"].values())[-1][0]
@@ -360,7 +362,7 @@
                 org_name = v[: -len(suffix)]
                 break
 
-    if org_name is None:
+    if org_name.startswith("Unknown:"):
         org_name = lname_to_org_name(n_fx_org_mod_name)
 
     return org_name
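Note on the second hunk: org_name is initialized to the string f"Unknown:{node.name}", so the previous fallback test (if org_name is None) could never fire; checking for the "Unknown:" prefix makes the fallback reachable again. A minimal sketch of that sentinel pattern follows (resolve_org_name and its name-mangling fallback are hypothetical; the real helper walks node.meta["nn_module_stack"]):

    def resolve_org_name(node_name, lut_fx2org=None):
        """Illustrative sentinel-fallback pattern, not the library implementation."""
        if lut_fx2org is None:
            lut_fx2org = {}
        org_name = f"Unknown:{node_name}"            # sentinel string, never None
        org_name = lut_fx2org.get(node_name, org_name)
        if org_name.startswith("Unknown:"):          # fallback is now reachable
            org_name = node_name.replace("_", ".")   # hypothetical last-resort mapping
        return org_name

    print(resolve_org_name("layer1_0_conv1"))  # -> "layer1.0.conv1"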

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ dependencies = [
     "numpy>=1.26.4,<2.3.0",
     "accelerate>=0.20.3,!=0.34,<1.4",
     "transformers>=4.45,<4.49",
-    "torch>=2.2.0,<2.5",
+    "torch>=2.2.0,<2.6",
     "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
