@@ -46,7 +46,7 @@ def dfs_gm(
4646 prescreenOp = None ,
4747 hook = None ,
4848 return_nodes = False ,
49- LUTfx_mod_name_to_org = {},
49+ lut_fx_mod_name_to_org = {},
5050):
5151 """Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5252 Because FX IR is a higher level IR, should have much fewer
@@ -217,7 +217,7 @@ def _dfs(curr_node, depth):
217217 for n_ln , d in node_found .items ():
218218 n , line_num = n_ln # unpack tuple
219219 org_mod_names [
220- get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org ), line_num
220+ get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org ), line_num
221221 ] = d # see Note 2
222222
223223 return dict (
@@ -227,7 +227,7 @@ def _dfs(curr_node, depth):
227227 return node_found
228228
229229
230- def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule , LUTfx_mod_name_to_org = {}):
230+ def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule , lut_fx_mod_name_to_org = {}):
231231 """Identify Conv on shortcut using FX GM DFS
232232 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
233233 original model, not FX module names)
@@ -335,18 +335,18 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={})
335335 conv_mod = gm .get_submodule (n_conv_i .target )
336336 else :
337337 conv_mod = get_org_mod_name_of_fx_node (
338- n_conv_i , LUTfx2org = LUTfx_mod_name_to_org
338+ n_conv_i , lut_fx2org = lut_fx_mod_name_to_org
339339 )
340340 if conv_mod .out_channels > conv_mod .in_channels : # see Note 2
341341 qconv_candidate .append (
342- get_org_mod_name_of_fx_node (n_conv_i , gm , LUTfx_mod_name_to_org )
342+ get_org_mod_name_of_fx_node (n_conv_i , gm , lut_fx_mod_name_to_org )
343343 )
344344
345345 return qconv_candidate
346346
347347
348348def find_1st_last_gm (
349- gm , firstOps = None , lastOps = None , return_1st_last_sep = False , LUTfx_mod_name_to_org = {}
349+ gm , firstOps = None , lastOps = None , return_1st_last_sep = False , lut_fx_mod_name_to_org = {}
350350):
351351 """Identify the first and last layer of interests
352352 Usually only interested in Conv and Linear, but could be others as well
@@ -366,14 +366,14 @@ def find_1st_last_gm(
366366 gm ,
367367 targetOp = firstOps ,
368368 find1stOnly = True ,
369- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
369+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
370370 )
371371 last_candidates = dfs_gm (
372372 gm ,
373373 targetOp = lastOps ,
374374 find1stOnly = True ,
375375 reverse = True ,
376- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
376+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
377377 )
378378
379379 min_depth = min (list (first_candidates .values ()) + [999 ])
@@ -397,7 +397,7 @@ def find_1st_last_gm(
397397
398398
399399def find_single_sided_op_gm (
400- gm , op_of_interest = None , return_LUTs = False , verbose = False , LUTfx_mod_name_to_org = {}
400+ gm , op_of_interest = None , return_LUTs = False , verbose = False , lut_fx_mod_name_to_org = {}
401401):
402402 """Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
403403 itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -546,7 +546,7 @@ def find_single_sided_op_gm(
546546 gm ,
547547 targetOp = op_of_interest ,
548548 return_nodes = True ,
549- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
549+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
550550 )
551551
552552 SingleSidedOps = []
@@ -557,7 +557,7 @@ def find_single_sided_op_gm(
557557 risky_nodes .append (n )
558558 if all (input_pos ):
559559 SingleSidedOps .append (
560- get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )
560+ get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org )
561561 )
562562
563563 if risky_nodes :
@@ -570,7 +570,7 @@ def find_single_sided_op_gm(
570570 return SingleSidedOps
571571
572572
573- def find_qkvsync_candidates_gm (gm , return_nodes = False , LUTfx_mod_name_to_org = {}):
573+ def find_qkvsync_candidates_gm (gm , return_nodes = False , lut_fx_mod_name_to_org = {}):
574574 """Identify groups of Linears that share the same parent. It's a transformer-specific feature.
575575
576576 NOTE:
@@ -609,7 +609,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
609609 for depth , nodes in LUTdep2nodes .items ():
610610 parents = [ni .all_input_nodes [0 ] for ni in nodes ]
611611 org_mod_names = [
612- get_org_mod_name_of_fx_node (ni , gm , LUTfx_mod_name_to_org ) for ni in nodes
612+ get_org_mod_name_of_fx_node (ni , gm , lut_fx_mod_name_to_org ) for ni in nodes
613613 ]
614614 if all (p == parents [0 ] for p in parents [1 :]):
615615 Nshared_parents += 1
@@ -620,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
620620 return my_1st_sibling
621621
622622
623- def find_silu_gm (gm , LUTfx_mod_name_to_org = {}):
623+ def find_silu_gm (gm , lut_fx_mod_name_to_org = {}):
624624 """Special handle for Conv following silu, specific for EffDet and etc
625625 LLM could use SiLU as well (llama?), but not relavent to this func
626626 """
@@ -633,14 +633,14 @@ def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
633633 gpOp = get_target_op_from_node (gp_nodes [0 ], gm ) if gp_nodes else None
634634
635635 if torch .nn .functional .silu in [pOp , gpOp ]:
636- siluConv [get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )] = {
636+ siluConv [get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org )] = {
637637 "qa_mode" : "qsilu"
638638 }
639639
640640 return siluConv
641641
642642
643- def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 , LUTfx_mod_name_to_org = {}):
643+ def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 , lut_fx_mod_name_to_org = {}):
644644 """For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
645645 are commonly used. prefer to skip them, but may be ok to quantize in some cases.
646646
@@ -703,7 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
703703 targetOp = [torch .nn .Conv2d ],
704704 start_nodes = fpn_st_nodes ,
705705 stop_nodes = [fpn_end_node ],
706- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
706+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
707707 )
708708 fpn_convs = [mod_name for mod_name , ln in fpn_convs .keys ()] # see Note 4
709709 fpn_adds = dfs_gm (
@@ -730,7 +730,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
730730 ):
731731 fpn_inner_blocks .append (
732732 get_org_mod_name_of_fx_node (
733- gp , LUTfx2org = LUTfx_mod_name_to_org
733+ gp , lut_fx2org = lut_fx_mod_name_to_org
734734 )
735735 )
736736 fpn_convs += fpn_inner_blocks
@@ -744,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
744744 return fpn_convs
745745
746746
747- def find_and_prep_bmm_gm (gm , LUTfx_mod_name_to_org = {}):
747+ def find_and_prep_bmm_gm (gm , lut_fx_mod_name_to_org = {}):
748748 """Previously with TorchScript, we use this func to perform 2 tasks:
749749 a) create QBmms, and then attach them to the model,
750750 b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -798,7 +798,7 @@ def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
798798 LUTmodname2linenum = {} # see Note 4
799799 for node_line_num , depth in LUT2sort .items ():
800800 node , line_num = node_line_num
801- org_mod_name = get_org_mod_name_of_fx_node (node , gm , LUTfx_mod_name_to_org )
801+ org_mod_name = get_org_mod_name_of_fx_node (node , gm , lut_fx_mod_name_to_org )
802802 if org_mod_name in LUTmodname2linenum :
803803 LUTmodname2linenum [org_mod_name ] += [(node , line_num , depth )]
804804 else :
@@ -880,7 +880,7 @@ def model_analyzer(
880880 2. Use Dynamo to replace TorchScript tracing in old qmodel_prep(),
881881
882882 NOTE:
883- 1. Will use LUTweight2modname to find the prefix for subgraphs, should graph break. As module
883+ 1. Will use lut_weight2modname to find the prefix for subgraphs, should graph break. As module
884884 seems to have extra layer of wrapper from Dynamo, matching module, i.e. id(module), may lead
885885 to incorrect results, matching weights (tensor) should be consistent.
886886 2. For subgraph, we might be getting a partial "original name", such as layer.0.xxx instead of
@@ -894,7 +894,7 @@ def model_analyzer(
894894 """
895895
896896 qcfg ["N_backend_called" ] = 0
897- LUTweight2modname = {
897+ lut_weight2modname = {
898898 mod .weight : n
899899 for n , mod in model .named_modules ()
900900 if isinstance (mod , (torch .nn .Linear , torch .nn .Conv2d ))
@@ -941,20 +941,20 @@ def cus_backend_model_analyzer(
941941
942942 """
943943 qcfg ["N_backend_called" ] += 1
944- LUTfx_mod_name_to_org = {
945- n .replace (".weight" , "" ): LUTweight2modname [p ]
944+ lut_fx_mod_name_to_org = {
945+ n .replace (".weight" , "" ): lut_weight2modname [p ]
946946 for n , p in gm_fx .named_parameters ()
947- if p in LUTweight2modname
947+ if p in lut_weight2modname
948948 }
949949 prefix = None
950950 if qcfg ["N_backend_called" ] > 1 : # subgraph found, see Note 2
951951 for n in gm_fx .graph .nodes :
952952 if n .op == "call_module" :
953953 mod = gm_fx .get_submodule (n .target )
954954 if isinstance (mod , (torch .nn .Linear , torch .nn .Conv2d )):
955- real_org_modname = LUTweight2modname [mod .weight ]
955+ real_org_modname = lut_weight2modname [mod .weight ]
956956 part_org_modname = get_org_mod_name_of_fx_node (
957- n , gm_fx , LUTfx_mod_name_to_org
957+ n , gm_fx , lut_fx_mod_name_to_org
958958 )
959959 idx = real_org_modname .rindex (part_org_modname )
960960 if idx > 1 :
@@ -972,7 +972,7 @@ def cus_backend_model_analyzer(
972972 outputname = plotsvg ,
973973 show_details = True ,
974974 Nnode_to_plot = 1000 ,
975- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
975+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
976976 )
977977
978978 # Graph checks begin. Use append, add prefix if needed
@@ -984,7 +984,7 @@ def cus_backend_model_analyzer(
984984 if isinstance (m , torch .nn .Conv2d ) or issubclass (type (m ), torch .nn .Conv2d )
985985 ]
986986 if len (all_conv ) > 0 :
987- skip_candidates += find_conv_on_shortcut_gm (gm_fx , LUTfx_mod_name_to_org )
987+ skip_candidates += find_conv_on_shortcut_gm (gm_fx , lut_fx_mod_name_to_org )
988988
989989 # Check 2. first/last, see Note 2 and 3
990990 if qcfg ["N_backend_called" ] > 1 :
@@ -993,19 +993,19 @@ def cus_backend_model_analyzer(
993993 _ , last_only = find_1st_last_gm (
994994 gm_fx ,
995995 return_1st_last_sep = True ,
996- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
996+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
997997 )
998998 skip_candidates += last_only
999999 else :
10001000 # see Note 4
10011001 skip_candidates += find_1st_last_gm (
1002- gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1002+ gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org
10031003 )
10041004 qcfg ["qskip_layer_name" ] += add_prefix_to_list_or_dict (skip_candidates , prefix )
10051005
10061006 # Check 3: single/double sided
10071007 qcfg ["qsinglesided_name" ] += add_prefix_to_list_or_dict (
1008- find_single_sided_op_gm (gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ),
1008+ find_single_sided_op_gm (gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ),
10091009 prefix ,
10101010 )
10111011
@@ -1016,12 +1016,12 @@ def cus_backend_model_analyzer(
10161016 # Check 5: Conv+SiLU
10171017 qcfg ["qspecial_layers" ].update (
10181018 add_prefix_to_list_or_dict (
1019- find_silu_gm (gm_fx , LUTfx_mod_name_to_org ), prefix
1019+ find_silu_gm (gm_fx , lut_fx_mod_name_to_org ), prefix
10201020 )
10211021 )
10221022
10231023 # Check 6: BMM
1024- temp_dict = find_and_prep_bmm_gm (gm_fx , LUTfx_mod_name_to_org ) # see Note 5
1024+ temp_dict = find_and_prep_bmm_gm (gm_fx , lut_fx_mod_name_to_org ) # see Note 5
10251025 if len (temp_dict ["layers_with_bmm" ]) > 0 :
10261026 temp_dict ["layers_with_bmm" ] = add_prefix_to_list_or_dict (
10271027 temp_dict ["layers_with_bmm" ], prefix
@@ -1033,7 +1033,7 @@ def cus_backend_model_analyzer(
10331033
10341034 # Check 7: QKV
10351035 temp_dict = find_qkvsync_candidates_gm (
1036- gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1036+ gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org
10371037 ) # see Note 6
10381038 temp_dict = add_prefix_to_list_or_dict (
10391039 temp_dict , prefix , update_both_k_and_v = True
0 commit comments