@@ -46,7 +46,7 @@ def dfs_gm(
4646 prescreenOp = None ,
4747 hook = None ,
4848 return_nodes = False ,
49- LUTfx_mod_name_to_org = {},
49+ lut_fx_mod_name_to_org = {},
5050):
5151 """Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5252 Because FX IR is a higher level IR, should have much fewer
@@ -217,7 +217,7 @@ def _dfs(curr_node, depth):
217217 for n_ln , d in node_found .items ():
218218 n , line_num = n_ln # unpack tuple
219219 org_mod_names [
220- get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org ), line_num
220+ get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org ), line_num
221221 ] = d # see Note 2
222222
223223 return dict (
@@ -227,7 +227,7 @@ def _dfs(curr_node, depth):
227227 return node_found
228228
229229
230- def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule , LUTfx_mod_name_to_org = {}):
230+ def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule , lut_fx_mod_name_to_org = {}):
231231 """Identify Conv on shortcut using FX GM DFS
232232 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
233233 original model, not FX module names)
@@ -335,18 +335,18 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={})
335335 conv_mod = gm .get_submodule (n_conv_i .target )
336336 else :
337337 conv_mod = get_org_mod_name_of_fx_node (
338- n_conv_i , LUTfx2org = LUTfx_mod_name_to_org
338+ n_conv_i , lut_fx2org = lut_fx_mod_name_to_org
339339 )
340340 if conv_mod .out_channels > conv_mod .in_channels : # see Note 2
341341 qconv_candidate .append (
342- get_org_mod_name_of_fx_node (n_conv_i , gm , LUTfx_mod_name_to_org )
342+ get_org_mod_name_of_fx_node (n_conv_i , gm , lut_fx_mod_name_to_org )
343343 )
344344
345345 return qconv_candidate
346346
347347
348348def find_1st_last_gm (
349- gm , firstOps = None , lastOps = None , return_1st_last_sep = False , LUTfx_mod_name_to_org = {}
349+ gm , firstOps = None , lastOps = None , return_1st_last_sep = False , lut_fx_mod_name_to_org = {}
350350):
351351 """Identify the first and last layer of interests
352352 Usually only interested in Conv and Linear, but could be others as well
@@ -366,14 +366,14 @@ def find_1st_last_gm(
366366 gm ,
367367 targetOp = firstOps ,
368368 find1stOnly = True ,
369- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
369+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
370370 )
371371 last_candidates = dfs_gm (
372372 gm ,
373373 targetOp = lastOps ,
374374 find1stOnly = True ,
375375 reverse = True ,
376- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
376+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
377377 )
378378
379379 min_depth = min (list (first_candidates .values ()) + [999 ])
@@ -397,7 +397,7 @@ def find_1st_last_gm(
397397
398398
399399def find_single_sided_op_gm (
400- gm , op_of_interest = None , return_LUTs = False , verbose = False , LUTfx_mod_name_to_org = {}
400+ gm , op_of_interest = None , return_LUTs = False , verbose = False , lut_fx_mod_name_to_org = {}
401401):
402402 """Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
403403 itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -546,7 +546,7 @@ def find_single_sided_op_gm(
546546 gm ,
547547 targetOp = op_of_interest ,
548548 return_nodes = True ,
549- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
549+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
550550 )
551551
552552 SingleSidedOps = []
@@ -557,7 +557,7 @@ def find_single_sided_op_gm(
557557 risky_nodes .append (n )
558558 if all (input_pos ):
559559 SingleSidedOps .append (
560- get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )
560+ get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org )
561561 )
562562
563563 if risky_nodes :
@@ -570,7 +570,7 @@ def find_single_sided_op_gm(
570570 return SingleSidedOps
571571
572572
573- def find_qkvsync_candidates_gm (gm , return_nodes = False , LUTfx_mod_name_to_org = {}):
573+ def find_qkvsync_candidates_gm (gm , return_nodes = False , lut_fx_mod_name_to_org = {}):
574574 """Identify groups of Linears that share the same parent. It's a transformer-specific feature.
575575
576576 NOTE:
@@ -609,7 +609,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
609609 for depth , nodes in LUTdep2nodes .items ():
610610 parents = [ni .all_input_nodes [0 ] for ni in nodes ]
611611 org_mod_names = [
612- get_org_mod_name_of_fx_node (ni , gm , LUTfx_mod_name_to_org ) for ni in nodes
612+ get_org_mod_name_of_fx_node (ni , gm , lut_fx_mod_name_to_org ) for ni in nodes
613613 ]
614614 if all (p == parents [0 ] for p in parents [1 :]):
615615 Nshared_parents += 1
@@ -620,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
620620 return my_1st_sibling
621621
622622
623- def find_silu_gm (gm , LUTfx_mod_name_to_org = {}):
623+ def find_silu_gm (gm , lut_fx_mod_name_to_org = {}):
624624 """Special handle for Conv following silu, specific for EffDet and etc
625625 LLM could use SiLU as well (llama?), but not relevant to this func
626626 """
@@ -633,14 +633,14 @@ def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
633633 gpOp = get_target_op_from_node (gp_nodes [0 ], gm ) if gp_nodes else None
634634
635635 if torch .nn .functional .silu in [pOp , gpOp ]:
636- siluConv [get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )] = {
636+ siluConv [get_org_mod_name_of_fx_node (n , gm , lut_fx_mod_name_to_org )] = {
637637 "qa_mode" : "qsilu"
638638 }
639639
640640 return siluConv
641641
642642
643- def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 , LUTfx_mod_name_to_org = {}):
643+ def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 , lut_fx_mod_name_to_org = {}):
644644 """For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
645645 are commonly used. prefer to skip them, but may be ok to quantize in some cases.
646646
@@ -703,7 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
703703 targetOp = [torch .nn .Conv2d ],
704704 start_nodes = fpn_st_nodes ,
705705 stop_nodes = [fpn_end_node ],
706- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
706+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
707707 )
708708 fpn_convs = [mod_name for mod_name , ln in fpn_convs .keys ()] # see Note 4
709709 fpn_adds = dfs_gm (
@@ -730,7 +730,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
730730 ):
731731 fpn_inner_blocks .append (
732732 get_org_mod_name_of_fx_node (
733- gp , LUTfx2org = LUTfx_mod_name_to_org
733+ gp , lut_fx2org = lut_fx_mod_name_to_org
734734 )
735735 )
736736 fpn_convs += fpn_inner_blocks
@@ -744,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
744744 return fpn_convs
745745
746746
747- def find_and_prep_bmm_gm (gm , LUTfx_mod_name_to_org = {}):
747+ def find_and_prep_bmm_gm (gm , lut_fx_mod_name_to_org = {}):
748748 """Previously with TorchScript, we use this func to perform 2 tasks:
749749 a) create QBmms, and then attach them to the model,
750750 b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -798,7 +798,7 @@ def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
798798 LUTmodname2linenum = {} # see Note 4
799799 for node_line_num , depth in LUT2sort .items ():
800800 node , line_num = node_line_num
801- org_mod_name = get_org_mod_name_of_fx_node (node , gm , LUTfx_mod_name_to_org )
801+ org_mod_name = get_org_mod_name_of_fx_node (node , gm , lut_fx_mod_name_to_org )
802802 if org_mod_name in LUTmodname2linenum :
803803 LUTmodname2linenum [org_mod_name ] += [(node , line_num , depth )]
804804 else :
@@ -941,7 +941,7 @@ def cus_backend_model_analyzer(
941941
942942 """
943943 qcfg ["N_backend_called" ] += 1
944- LUTfx_mod_name_to_org = {
944+ lut_fx_mod_name_to_org = {
945945 n .replace (".weight" , "" ): LUTweight2modname [p ]
946946 for n , p in gm_fx .named_parameters ()
947947 if p in LUTweight2modname
@@ -954,7 +954,7 @@ def cus_backend_model_analyzer(
954954 if isinstance (mod , (torch .nn .Linear , torch .nn .Conv2d )):
955955 real_org_modname = LUTweight2modname [mod .weight ]
956956 part_org_modname = get_org_mod_name_of_fx_node (
957- n , gm_fx , LUTfx_mod_name_to_org
957+ n , gm_fx , lut_fx_mod_name_to_org
958958 )
959959 idx = real_org_modname .rindex (part_org_modname )
960960 if idx > 1 :
@@ -972,7 +972,7 @@ def cus_backend_model_analyzer(
972972 outputname = plotsvg ,
973973 show_details = True ,
974974 Nnode_to_plot = 1000 ,
975- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
975+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
976976 )
977977
978978 # Graph checks begin. Use append, add prefix if needed
@@ -984,7 +984,7 @@ def cus_backend_model_analyzer(
984984 if isinstance (m , torch .nn .Conv2d ) or issubclass (type (m ), torch .nn .Conv2d )
985985 ]
986986 if len (all_conv ) > 0 :
987- skip_candidates += find_conv_on_shortcut_gm (gm_fx , LUTfx_mod_name_to_org )
987+ skip_candidates += find_conv_on_shortcut_gm (gm_fx , lut_fx_mod_name_to_org )
988988
989989 # Check 2. first/last, see Note 2 and 3
990990 if qcfg ["N_backend_called" ] > 1 :
@@ -993,19 +993,19 @@ def cus_backend_model_analyzer(
993993 _ , last_only = find_1st_last_gm (
994994 gm_fx ,
995995 return_1st_last_sep = True ,
996- LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
996+ lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ,
997997 )
998998 skip_candidates += last_only
999999 else :
10001000 # see Note 4
10011001 skip_candidates += find_1st_last_gm (
1002- gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1002+ gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org
10031003 )
10041004 qcfg ["qskip_layer_name" ] += add_prefix_to_list_or_dict (skip_candidates , prefix )
10051005
10061006 # Check 3: single/double sided
10071007 qcfg ["qsinglesided_name" ] += add_prefix_to_list_or_dict (
1008- find_single_sided_op_gm (gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ),
1008+ find_single_sided_op_gm (gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org ),
10091009 prefix ,
10101010 )
10111011
@@ -1016,12 +1016,12 @@ def cus_backend_model_analyzer(
10161016 # Check 5: Conv+SiLU
10171017 qcfg ["qspecial_layers" ].update (
10181018 add_prefix_to_list_or_dict (
1019- find_silu_gm (gm_fx , LUTfx_mod_name_to_org ), prefix
1019+ find_silu_gm (gm_fx , lut_fx_mod_name_to_org ), prefix
10201020 )
10211021 )
10221022
10231023 # Check 6: BMM
1024- temp_dict = find_and_prep_bmm_gm (gm_fx , LUTfx_mod_name_to_org ) # see Note 5
1024+ temp_dict = find_and_prep_bmm_gm (gm_fx , lut_fx_mod_name_to_org ) # see Note 5
10251025 if len (temp_dict ["layers_with_bmm" ]) > 0 :
10261026 temp_dict ["layers_with_bmm" ] = add_prefix_to_list_or_dict (
10271027 temp_dict ["layers_with_bmm" ], prefix
@@ -1033,7 +1033,7 @@ def cus_backend_model_analyzer(
10331033
10341034 # Check 7: QKV
10351035 temp_dict = find_qkvsync_candidates_gm (
1036- gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1036+ gm_fx , lut_fx_mod_name_to_org = lut_fx_mod_name_to_org
10371037 ) # see Note 6
10381038 temp_dict = add_prefix_to_list_or_dict (
10391039 temp_dict , prefix , update_both_k_and_v = True
0 commit comments