@@ -47,7 +47,7 @@ def dfs_gm(
4747 prescreenOp = None ,
4848 hook = None ,
4949 return_nodes = False ,
50- lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ,
50+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
5151):
5252 """Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5353 Because FX IR is a higher level IR, should have much fewer
@@ -229,7 +229,7 @@ def _dfs(curr_node, depth):
229229
230230
231231def find_conv_on_shortcut_gm (
232- gm : torch .fx .GraphModule , lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None
232+ gm : torch .fx .GraphModule , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None
233233):
234234 """Identify Conv on shortcut using FX GM DFS
235235 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -355,7 +355,7 @@ def find_1st_last_gm(
355355 firstOps = None ,
356356 lastOps = None ,
357357 return_1st_last_sep = False ,
358- lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ,
358+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
359359):
360360 """Identify the first and last layer of interests
361361 Usually only interested in Conv and Linear, but could be others as well
@@ -410,7 +410,7 @@ def find_single_sided_op_gm(
410410 op_of_interest = None ,
411411 return_LUTs = False ,
412412 verbose = False ,
413- lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ,
413+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
414414):
415415 """Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
416416 itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -584,7 +584,7 @@ def find_single_sided_op_gm(
584584
585585
586586def find_qkvsync_candidates_gm (
587- gm , return_nodes = False , lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None
587+ gm , return_nodes = False , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None
588588):
589589 """Identify groups of Linears that share the same parent. It's a transformer-specific feature.
590590
@@ -635,7 +635,7 @@ def find_qkvsync_candidates_gm(
635635 return my_1st_sibling
636636
637637
638- def find_silu_gm (gm , lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ):
638+ def find_silu_gm (gm , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ):
639639 """Special handle for Conv following silu, specific for EffDet and etc
640640 LLM could use SiLU as well (llama?), but not relavent to this func
641641 """
@@ -659,7 +659,7 @@ def find_rpn_fpn_gm(
659659 gm ,
660660 verbose = False ,
661661 Nsubgraph = 0 ,
662- lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ,
662+ lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
663663):
664664 """For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
665665 are commonly used. prefer to skip them, but may be ok to quantize in some cases.
@@ -764,7 +764,7 @@ def find_rpn_fpn_gm(
764764 return fpn_convs
765765
766766
767- def find_and_prep_bmm_gm (gm , lut_fx_mod_name_to_org : Optional [Dict [int , str ]] = None ):
767+ def find_and_prep_bmm_gm (gm , lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ):
768768 """Previously with TorchScript, we use this func to perform 2 tasks:
769769 a) create QBmms, and then attach them to the model,
770770 b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
0 commit comments