3333logger = logging .getLogger (__name__ )
3434
3535
def run_fwd_once(model, sample_inp):
    """Run a single forward pass of ``model`` on ``sample_inp`` under ``torch.no_grad()``.

    The sample input is dispatched on its type:

    * ``dict``, or any object exposing ``keys``/``values``/``items``:
      ``model(**sample_inp)``
    * ``tuple``: ``model(*sample_inp)``
    * ``torch.Tensor``: ``model(sample_inp)``
    * anything else: ``model(sample_inp)`` is attempted as-is; a
      ``RuntimeError`` from that attempt is logged instead of propagated.

    Args:
        model: callable to run (a module, a compiled model, ...).
        sample_inp: example input to feed to ``model``.

    Returns:
        Whatever ``model`` returns, or ``None`` when the input type was not
        recognized and the best-effort call raised ``RuntimeError``.
    """
    # Pre-bind the result: the best-effort branch below may only log, and the
    # final return must not hit an unbound local in that case.
    out = None
    is_mapping_like = isinstance(sample_inp, dict) or all(
        hasattr(sample_inp, k) for k in ("keys", "values", "items")
    )
    with torch.no_grad():
        if is_mapping_like:
            out = model(**sample_inp)
        elif isinstance(sample_inp, tuple):
            out = model(*sample_inp)
        elif isinstance(sample_inp, torch.Tensor):
            out = model(sample_inp)
        else:
            try:
                # assume user provided input is ready-to-run...
                out = model(sample_inp)
            except RuntimeError:
                logger.info(
                    f"Unknown data structure for example_input.{type(sample_inp)} Please check."
                )
    return out
55+
56+
3657def dfs_gm (
3758 gm ,
3859 targetOp = None ,
@@ -229,7 +250,9 @@ def _dfs(curr_node, depth):
229250
230251
231252def find_conv_on_shortcut_gm (
232- gm : torch .fx .GraphModule , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None
253+ gm : torch .fx .GraphModule ,
254+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
255+ lut_name_to_mod = {},
233256):
234257 """Identify Conv on shortcut using FX GM DFS
235258 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -337,9 +360,13 @@ def find_conv_on_shortcut_gm(
337360 if n_conv_i .op == "call_module" :
338361 conv_mod = gm .get_submodule (n_conv_i .target )
339362 else :
340- conv_mod = get_org_mod_name_of_fx_node (
363+ # in case aten IR is being used
364+ conv_mod_name = get_org_mod_name_of_fx_node (
341365 n_conv_i , lut_fx2org = lut_fx_mod_name_to_org
342366 )
367+ conv_mod = lut_name_to_mod .get (conv_mod_name , None )
368+ if not isinstance (conv_mod , torch .nn .Conv2d ):
369+ continue
343370 if conv_mod .out_channels > conv_mod .in_channels : # see Note 2
344371 qconv_candidate .append (
345372 get_org_mod_name_of_fx_node (
@@ -1003,8 +1030,17 @@ def cus_backend_model_analyzer(
10031030 for _ , m in gm_fx .named_modules ()
10041031 if isinstance (m , torch .nn .Conv2d ) or issubclass (type (m ), torch .nn .Conv2d )
10051032 ]
1006- if len (all_conv ) > 0 :
1007- skip_candidates += find_conv_on_shortcut_gm (gm_fx , lut_fx_mod_name_to_org )
1033+ # if gm is using aten IR, only ops can be seen, no modules.
1034+ conv_ops = dfs_gm (
1035+ gm_fx ,
1036+ targetOp = [torch .nn .Conv2d , torch .nn .functional .conv2d ],
1037+ return_nodes = True ,
1038+ )
1039+ lut_name_to_mod = {n : m for m , n in qcfg ["LUTmodule_name" ].items ()}
1040+ if len (all_conv ) > 0 or len (conv_ops ) > 0 :
1041+ skip_candidates += find_conv_on_shortcut_gm (
1042+ gm_fx , lut_fx_mod_name_to_org , lut_name_to_mod
1043+ )
10081044
10091045 # Check 2. first/last, see Note 2 and 3, NOTE that transformers are handled differently
10101046 if qcfg ["N_backend_called" ] > 1 :
@@ -1064,6 +1100,7 @@ def cus_backend_model_analyzer(
10641100 from functools import partial
10651101
10661102 # Third Party
1103+ from torchvision .models import VisionTransformer
10671104 from transformers import PreTrainedModel
10681105
10691106 if issubclass (type (model ), torch .nn .Module ):
@@ -1075,7 +1112,7 @@ def cus_backend_model_analyzer(
10751112 model_to_be_traced = model
10761113 model_param_size = 999
10771114
1078- is_transformers = issubclass (type (model ), PreTrainedModel )
1115+ is_transformers = issubclass (type (model ), ( PreTrainedModel , VisionTransformer ) )
10791116 if model_param_size > 1 :
10801117 # Standard
10811118 import sys
@@ -1111,35 +1148,25 @@ def call_seq_hook(mod, *_args, **_kwargs):
11111148 h_hooks .append (m .register_forward_hook (call_seq_hook ))
11121149
11131150 with torch .no_grad ():
1114- model ( ** sample_inp )
1151+ run_fwd_once ( model , sample_inp )
11151152
11161153 for h in h_hooks :
11171154 h .remove ()
11181155
11191156 # only add last layer
11201157 qcfg ["qskip_layer_name" ] += [qcfg ["mod_call_seq" ][- 1 ]]
1158+ # unless it's a ViT, skip first Conv as well
1159+ if issubclass (type (model ), VisionTransformer ) and isinstance (
1160+ model .get_submodule (qcfg ["mod_call_seq" ][0 ]), torch .nn .Conv2d
1161+ ):
1162+ qcfg ["qskip_layer_name" ] += [qcfg ["mod_call_seq" ][0 ]]
11211163
11221164 with torch .no_grad ():
11231165 model_opt = torch .compile (
11241166 model_to_be_traced ,
11251167 backend = cus_bknd ,
11261168 )
1127- if isinstance (sample_inp , dict ) or all (
1128- hasattr (sample_inp , k ) for k in ("keys" , "values" , "items" )
1129- ):
1130- model_opt (** sample_inp )
1131- elif isinstance (sample_inp , tuple ):
1132- model_opt (* sample_inp )
1133- elif isinstance (sample_inp , torch .Tensor ):
1134- model_opt (sample_inp )
1135- else :
1136- try :
1137- # assume user provided input is ready-to-run...
1138- model_opt (sample_inp )
1139- except RuntimeError :
1140- logger .info (
1141- f"Unknown data structure for example_input.{ type (sample_inp )} Please check."
1142- )
1169+ run_fwd_once (model_opt , sample_inp )
11431170
11441171 del model_opt
11451172
0 commit comments