3333logger = logging .getLogger (__name__ )
3434
3535
def run_fwd_once(model, sample_inp):
    """Run ``model`` once on ``sample_inp`` with gradients disabled.

    The input is unpacked according to its type:

    * ``dict``, or any object exposing ``keys``/``values``/``items``
      -> ``model(**sample_inp)``
    * ``tuple`` -> ``model(*sample_inp)``
    * ``torch.Tensor`` -> ``model(sample_inp)``
    * anything else -> ``model(sample_inp)`` on a best-effort basis; a
      ``RuntimeError`` raised here is logged instead of propagated.

    Args:
        model: callable to invoke.
        sample_inp: example input for the call.

    Returns:
        Whatever ``model`` returns, or ``None`` when the best-effort call on
        an unrecognized input type raised ``RuntimeError``.
    """
    # Pre-bind the result: the best-effort branch below may swallow a
    # RuntimeError, and without this the final `return out` would raise
    # UnboundLocalError instead of returning cleanly.
    out = None

    # Duck-typed mapping check so dict-like containers that do not subclass
    # dict are still unpacked as keyword arguments.
    is_mapping_like = isinstance(sample_inp, dict) or all(
        hasattr(sample_inp, k) for k in ("keys", "values", "items")
    )

    with torch.no_grad():
        if is_mapping_like:
            out = model(**sample_inp)
        elif isinstance(sample_inp, tuple):
            out = model(*sample_inp)
        elif isinstance(sample_inp, torch.Tensor):
            # Known-good input type: let any error propagate to the caller.
            out = model(sample_inp)
        else:
            try:
                # assume user provided input is ready-to-run...
                out = model(sample_inp)
            except RuntimeError:
                logger.info(
                    "Unknown data structure for example_input.%s Please check.",
                    type(sample_inp),
                )
    return out
56+
57+
3658def dfs_gm (
3759 gm ,
3860 targetOp = None ,
@@ -229,7 +251,9 @@ def _dfs(curr_node, depth):
229251
230252
231253def find_conv_on_shortcut_gm (
232- gm : torch .fx .GraphModule , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None
254+ gm : torch .fx .GraphModule ,
255+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
256+ lut_name_to_mod = None ,
233257):
234258 """Identify Conv on shortcut using FX GM DFS
235259 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -254,6 +278,9 @@ def find_conv_on_shortcut_gm(
254278 5. count levels of each branch, decide which one is the shortcut
255279 """
256280
281+ if lut_name_to_mod is None :
282+ lut_name_to_mod = {}
283+
257284 # 1. Find "add" nodes, including inplace add as some may use "out+=shortcut"
258285 nodes_add = dfs_gm (gm , ["add" ], return_nodes = True )
259286
@@ -337,9 +364,13 @@ def find_conv_on_shortcut_gm(
337364 if n_conv_i .op == "call_module" :
338365 conv_mod = gm .get_submodule (n_conv_i .target )
339366 else :
340- conv_mod = get_org_mod_name_of_fx_node (
367+ # in case aten IR is being used
368+ conv_mod_name = get_org_mod_name_of_fx_node (
341369 n_conv_i , lut_fx2org = lut_fx_mod_name_to_org
342370 )
371+ conv_mod = lut_name_to_mod .get (conv_mod_name , None )
372+ if not isinstance (conv_mod , torch .nn .Conv2d ):
373+ continue
343374 if conv_mod .out_channels > conv_mod .in_channels : # see Note 2
344375 qconv_candidate .append (
345376 get_org_mod_name_of_fx_node (
@@ -1003,8 +1034,17 @@ def cus_backend_model_analyzer(
10031034 for _ , m in gm_fx .named_modules ()
10041035 if isinstance (m , torch .nn .Conv2d ) or issubclass (type (m ), torch .nn .Conv2d )
10051036 ]
1006- if len (all_conv ) > 0 :
1007- skip_candidates += find_conv_on_shortcut_gm (gm_fx , lut_fx_mod_name_to_org )
1037+ # if gm is using aten IR, only ops can be seen, no modules.
1038+ conv_ops = dfs_gm (
1039+ gm_fx ,
1040+ targetOp = [torch .nn .Conv2d , torch .nn .functional .conv2d ],
1041+ return_nodes = True ,
1042+ )
1043+ lut_name_to_mod = {n : m for m , n in qcfg ["LUTmodule_name" ].items ()}
1044+ if len (all_conv ) > 0 or len (conv_ops ) > 0 :
1045+ skip_candidates += find_conv_on_shortcut_gm (
1046+ gm_fx , lut_fx_mod_name_to_org , lut_name_to_mod
1047+ )
10081048
10091049 # Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
10101050 if qcfg ["N_backend_called" ] > 1 :
@@ -1064,6 +1104,7 @@ def cus_backend_model_analyzer(
10641104 from functools import partial
10651105
10661106 # Third Party
1107+ from torchvision .models import VisionTransformer
10671108 from transformers import PreTrainedModel
10681109
10691110 if issubclass (type (model ), torch .nn .Module ):
@@ -1075,7 +1116,7 @@ def cus_backend_model_analyzer(
10751116 model_to_be_traced = model
10761117 model_param_size = 999
10771118
1078- is_transformers = issubclass (type (model ), PreTrainedModel )
1119+ is_transformers = issubclass (type (model ), ( PreTrainedModel , VisionTransformer ) )
10791120 if model_param_size > 1 :
10801121 # Standard
10811122 import sys
@@ -1111,35 +1152,25 @@ def call_seq_hook(mod, *_args, **_kwargs):
11111152 h_hooks .append (m .register_forward_hook (call_seq_hook ))
11121153
11131154 with torch .no_grad ():
1114- model ( ** sample_inp )
1155+ run_fwd_once ( model , sample_inp )
11151156
11161157 for h in h_hooks :
11171158 h .remove ()
11181159
11191160 # only add last layer
11201161 qcfg ["qskip_layer_name" ] += [qcfg ["mod_call_seq" ][- 1 ]]
1162+ # unless it's a ViT, skip first Conv as well
1163+ if issubclass (type (model ), VisionTransformer ) and isinstance (
1164+ model .get_submodule (qcfg ["mod_call_seq" ][0 ]), torch .nn .Conv2d
1165+ ):
1166+ qcfg ["qskip_layer_name" ] += [qcfg ["mod_call_seq" ][0 ]]
11211167
11221168 with torch .no_grad ():
11231169 model_opt = torch .compile (
11241170 model_to_be_traced ,
11251171 backend = cus_bknd ,
11261172 )
1127- if isinstance (sample_inp , dict ) or all (
1128- hasattr (sample_inp , k ) for k in ("keys" , "values" , "items" )
1129- ):
1130- model_opt (** sample_inp )
1131- elif isinstance (sample_inp , tuple ):
1132- model_opt (* sample_inp )
1133- elif isinstance (sample_inp , torch .Tensor ):
1134- model_opt (sample_inp )
1135- else :
1136- try :
1137- # assume user provided input is ready-to-run...
1138- model_opt (sample_inp )
1139- except RuntimeError :
1140- logger .info (
1141- f"Unknown data structure for example_input.{ type (sample_inp )} Please check."
1142- )
1173+ run_fwd_once (model_opt , sample_inp )
11431174
11441175 del model_opt
11451176
0 commit comments