@@ -46,6 +46,7 @@ def dfs_gm(
4646 prescreenOp = None ,
4747 hook = None ,
4848 return_nodes = False ,
49+ LUTfx_mod_name_to_org = {},
4950):
5051 """Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5152 Because FX IR is a higher level IR, should have much fewer
@@ -215,9 +216,9 @@ def _dfs(curr_node, depth):
215216 org_mod_names = {}
216217 for n_ln , d in node_found .items ():
217218 n , line_num = n_ln # unpack tuple
218- org_mod_names [get_org_mod_name_of_fx_node ( n , gm ), line_num ] = (
219- d # see Note 2
220- )
219+ org_mod_names [
220+ get_org_mod_name_of_fx_node ( n , gm , LUTfx_mod_name_to_org ), line_num
221+ ] = d # see Note 2
221222
222223 return dict (
223224 sorted (org_mod_names .items (), key = lambda item : item [1 ])
@@ -226,7 +227,7 @@ def _dfs(curr_node, depth):
226227 return node_found
227228
228229
229- def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule ):
230+ def find_conv_on_shortcut_gm (gm : torch .fx .GraphModule , LUTfx_mod_name_to_org = {} ):
230231 """Identify Conv on shortcut using FX GM DFS
231232 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
232233 original model, not FX module names)
@@ -333,14 +334,20 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule):
333334 if n_conv_i .op == "call_module" :
334335 conv_mod = gm .get_submodule (n_conv_i .target )
335336 else :
336- conv_mod = get_org_mod_name_of_fx_node (n_conv_i )
337+ conv_mod = get_org_mod_name_of_fx_node (
338+ n_conv_i , LUTfx2org = LUTfx_mod_name_to_org
339+ )
337340 if conv_mod .out_channels > conv_mod .in_channels : # see Note 2
338- qconv_candidate .append (get_org_mod_name_of_fx_node (n_conv_i , gm ))
341+ qconv_candidate .append (
342+ get_org_mod_name_of_fx_node (n_conv_i , gm , LUTfx_mod_name_to_org )
343+ )
339344
340345 return qconv_candidate
341346
342347
343- def find_1st_last_gm (gm , firstOps = None , lastOps = None , return_1st_last_sep = False ):
348+ def find_1st_last_gm (
349+ gm , firstOps = None , lastOps = None , return_1st_last_sep = False , LUTfx_mod_name_to_org = {}
350+ ):
344351 """Identify the first and last layer of interests
345352 Usually only interested in Conv and Linear, but could be others as well
346353 NOTE:
@@ -355,8 +362,19 @@ def find_1st_last_gm(gm, firstOps=None, lastOps=None, return_1st_last_sep=False)
355362 firstOps = [torch .nn .Conv2d , torch .nn .Linear ]
356363 if lastOps is None :
357364 lastOps = [torch .nn .Conv2d , torch .nn .Linear ]
358- first_candidates = dfs_gm (gm , targetOp = firstOps , find1stOnly = True )
359- last_candidates = dfs_gm (gm , targetOp = lastOps , find1stOnly = True , reverse = True )
365+ first_candidates = dfs_gm (
366+ gm ,
367+ targetOp = firstOps ,
368+ find1stOnly = True ,
369+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
370+ )
371+ last_candidates = dfs_gm (
372+ gm ,
373+ targetOp = lastOps ,
374+ find1stOnly = True ,
375+ reverse = True ,
376+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
377+ )
360378
361379 min_depth = min (list (first_candidates .values ()) + [999 ])
362380 real_first = [
@@ -379,10 +397,7 @@ def find_1st_last_gm(gm, firstOps=None, lastOps=None, return_1st_last_sep=False)
379397
380398
381399def find_single_sided_op_gm (
382- gm ,
383- op_of_interest = None ,
384- return_LUTs = False ,
385- verbose = False ,
400+ gm , op_of_interest = None , return_LUTs = False , verbose = False , LUTfx_mod_name_to_org = {}
386401):
387402 """Try to determine the "polarity" of the output of "every node" based on its inputs and the Op
388403 itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -527,7 +542,12 @@ def find_single_sided_op_gm(
527542 if return_LUTs :
528543 return isActOutPositiveOnly , isActOutBounded
529544
530- node_of_interest = dfs_gm (gm , targetOp = op_of_interest , return_nodes = True )
545+ node_of_interest = dfs_gm (
546+ gm ,
547+ targetOp = op_of_interest ,
548+ return_nodes = True ,
549+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
550+ )
531551
532552 SingleSidedOps = []
533553 risky_nodes = []
@@ -536,7 +556,9 @@ def find_single_sided_op_gm(
536556 if (False not in input_pos ) and (None in input_pos ): # see Note g
537557 risky_nodes .append (n )
538558 if all (input_pos ):
539- SingleSidedOps .append (get_org_mod_name_of_fx_node (n , gm ))
559+ SingleSidedOps .append (
560+ get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )
561+ )
540562
541563 if risky_nodes :
542564 logger .warning (
@@ -548,7 +570,7 @@ def find_single_sided_op_gm(
548570 return SingleSidedOps
549571
550572
551- def find_qkvsync_candidates_gm (gm , return_nodes = False ):
573+ def find_qkvsync_candidates_gm (gm , return_nodes = False , LUTfx_mod_name_to_org = {} ):
552574 """Identify groups of Linears that share the same parent. It's a transformer-specific feature.
553575
554576 NOTE:
@@ -586,7 +608,9 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False):
586608 Nshared_parents = 0
587609 for depth , nodes in LUTdep2nodes .items ():
588610 parents = [ni .all_input_nodes [0 ] for ni in nodes ]
589- org_mod_names = [get_org_mod_name_of_fx_node (ni , gm ) for ni in nodes ]
611+ org_mod_names = [
612+ get_org_mod_name_of_fx_node (ni , gm , LUTfx_mod_name_to_org ) for ni in nodes
613+ ]
590614 if all (p == parents [0 ] for p in parents [1 :]):
591615 Nshared_parents += 1
592616 for org_name_i in org_mod_names :
@@ -596,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False):
596620 return my_1st_sibling
597621
598622
599- def find_silu_gm (gm ):
623+ def find_silu_gm (gm , LUTfx_mod_name_to_org = {} ):
600624 """Special handle for Conv following silu, specific for EffDet and etc
601625 LLM could use SiLU as well (llama?), but not relevant to this func
602626 """
@@ -609,12 +633,14 @@ def find_silu_gm(gm):
609633 gpOp = get_target_op_from_node (gp_nodes [0 ], gm ) if gp_nodes else None
610634
611635 if torch .nn .functional .silu in [pOp , gpOp ]:
612- siluConv [get_org_mod_name_of_fx_node (n , gm )] = {"qa_mode" : "qsilu" }
636+ siluConv [get_org_mod_name_of_fx_node (n , gm , LUTfx_mod_name_to_org )] = {
637+ "qa_mode" : "qsilu"
638+ }
613639
614640 return siluConv
615641
616642
617- def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 ):
643+ def find_rpn_fpn_gm (gm , verbose = False , Nsubgraph = 0 , LUTfx_mod_name_to_org = {} ):
618644 """For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
619645 are commonly used. prefer to skip them, but may be ok to quantize in some cases.
620646
@@ -677,6 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
677703 targetOp = [torch .nn .Conv2d ],
678704 start_nodes = fpn_st_nodes ,
679705 stop_nodes = [fpn_end_node ],
706+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
680707 )
681708 fpn_convs = [mod_name for mod_name , ln in fpn_convs .keys ()] # see Note 4
682709 fpn_adds = dfs_gm (
@@ -701,7 +728,11 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
701728 isinstance (tar_op_gp , torch .nn .Conv2d )
702729 and tar_op_gp .kernel_size in [1 , (1 , 1 )]
703730 ):
704- fpn_inner_blocks .append (get_org_mod_name_of_fx_node (gp ))
731+ fpn_inner_blocks .append (
732+ get_org_mod_name_of_fx_node (
733+ gp , LUTfx2org = LUTfx_mod_name_to_org
734+ )
735+ )
705736 fpn_convs += fpn_inner_blocks
706737
707738 if verbose :
@@ -713,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
713744 return fpn_convs
714745
715746
716- def find_and_prep_bmm_gm (gm ):
747+ def find_and_prep_bmm_gm (gm , LUTfx_mod_name_to_org = {} ):
717748 """Previously with TorchScript, we use this func to perform 2 tasks:
718749 a) create QBmms, and then attach them to the model,
719750 b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -767,7 +798,7 @@ def find_and_prep_bmm_gm(gm):
767798 LUTmodname2linenum = {} # see Note 4
768799 for node_line_num , depth in LUT2sort .items ():
769800 node , line_num = node_line_num
770- org_mod_name = get_org_mod_name_of_fx_node (node , gm )
801+ org_mod_name = get_org_mod_name_of_fx_node (node , gm , LUTfx_mod_name_to_org )
771802 if org_mod_name in LUTmodname2linenum :
772803 LUTmodname2linenum [org_mod_name ] += [(node , line_num , depth )]
773804 else :
@@ -910,14 +941,21 @@ def cus_backend_model_analyzer(
910941
911942 """
912943 qcfg ["N_backend_called" ] += 1
944+ LUTfx_mod_name_to_org = {
945+ n .replace (".weight" , "" ): LUTweight2modname [p ]
946+ for n , p in gm_fx .named_parameters ()
947+ if p in LUTweight2modname
948+ }
913949 prefix = None
914950 if qcfg ["N_backend_called" ] > 1 : # subgraph found, see Note 2
915951 for n in gm_fx .graph .nodes :
916952 if n .op == "call_module" :
917953 mod = gm_fx .get_submodule (n .target )
918954 if isinstance (mod , (torch .nn .Linear , torch .nn .Conv2d )):
919955 real_org_modname = LUTweight2modname [mod .weight ]
920- part_org_modname = get_org_mod_name_of_fx_node (n , gm_fx )
956+ part_org_modname = get_org_mod_name_of_fx_node (
957+ n , gm_fx , LUTfx_mod_name_to_org
958+ )
921959 idx = real_org_modname .rindex (part_org_modname )
922960 if idx > 1 :
923961 prefix = real_org_modname [: idx - 1 ] # remove trailing '.'
@@ -930,28 +968,45 @@ def cus_backend_model_analyzer(
930968 if not isinstance (plotsvg , str ):
931969 plotsvg = f"debug{ qcfg ['N_backend_called' ]} .svg"
932970 plot_graph_module (
933- gm_fx , outputname = plotsvg , show_details = True , Nnode_to_plot = 1000
971+ gm_fx ,
972+ outputname = plotsvg ,
973+ show_details = True ,
974+ Nnode_to_plot = 1000 ,
975+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
934976 )
935977
936978 # Graph checks begin. Use append, add prefix if needed
937979 skip_candidates = []
938980 # Check 1. Conv on shortcut
939- skip_candidates += find_conv_on_shortcut_gm (gm_fx )
981+ all_conv = [
982+ m
983+ for _ , m in gm_fx .named_modules ()
984+ if isinstance (m , torch .nn .Conv2d ) or issubclass (type (m ), torch .nn .Conv2d )
985+ ]
986+ if len (all_conv ) > 0 :
987+ skip_candidates += find_conv_on_shortcut_gm (gm_fx , LUTfx_mod_name_to_org )
940988
941989 # Check 2. first/last, see Note 2 and 3
942990 if qcfg ["N_backend_called" ] > 1 :
943991 skip_candidates += []
944992 elif is_transformers :
945- _ , last_only = find_1st_last_gm (gm_fx , return_1st_last_sep = True )
993+ _ , last_only = find_1st_last_gm (
994+ gm_fx ,
995+ return_1st_last_sep = True ,
996+ LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ,
997+ )
946998 skip_candidates += last_only
947999 else :
9481000 # see Note 4
949- skip_candidates += find_1st_last_gm (gm_fx )
1001+ skip_candidates += find_1st_last_gm (
1002+ gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1003+ )
9501004 qcfg ["qskip_layer_name" ] += add_prefix_to_list_or_dict (skip_candidates , prefix )
9511005
9521006 # Check 3: single/double sided
9531007 qcfg ["qsinglesided_name" ] += add_prefix_to_list_or_dict (
954- find_single_sided_op_gm (gm_fx ), prefix
1008+ find_single_sided_op_gm (gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org ),
1009+ prefix ,
9551010 )
9561011
9571012 # Check 4: identify RPN/FPN
@@ -960,11 +1015,13 @@ def cus_backend_model_analyzer(
9601015 # NOTE: The following 3 funcs return dict instead of list. Use update() instead of append().
9611016 # Check 5: Conv+SiLU
9621017 qcfg ["qspecial_layers" ].update (
963- add_prefix_to_list_or_dict (find_silu_gm (gm_fx ), prefix )
1018+ add_prefix_to_list_or_dict (
1019+ find_silu_gm (gm_fx , LUTfx_mod_name_to_org ), prefix
1020+ )
9641021 )
9651022
9661023 # Check 6: BMM
967- temp_dict = find_and_prep_bmm_gm (gm_fx ) # see Note 5
1024+ temp_dict = find_and_prep_bmm_gm (gm_fx , LUTfx_mod_name_to_org ) # see Note 5
9681025 if len (temp_dict ["layers_with_bmm" ]) > 0 :
9691026 temp_dict ["layers_with_bmm" ] = add_prefix_to_list_or_dict (
9701027 temp_dict ["layers_with_bmm" ], prefix
@@ -975,7 +1032,9 @@ def cus_backend_model_analyzer(
9751032 qcfg ["bmm_prep" ]["layers_with_bmm" ].update (temp_dict ["layers_with_bmm" ])
9761033
9771034 # Check 7: QKV
978- temp_dict = find_qkvsync_candidates_gm (gm_fx ) # see Note 6
1035+ temp_dict = find_qkvsync_candidates_gm (
1036+ gm_fx , LUTfx_mod_name_to_org = LUTfx_mod_name_to_org
1037+ ) # see Note 6
9791038 temp_dict = add_prefix_to_list_or_dict (
9801039 temp_dict , prefix , update_both_k_and_v = True
9811040 )
0 commit comments