1717"""
1818
1919# Standard
20+ from typing import Dict , Optional
2021import logging
2122
2223# Third Party
@@ -46,7 +47,7 @@ def dfs_gm(
     prescreenOp=None,
     hook=None,
     return_nodes=False,
-    lut_fx_mod_name_to_org={},
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
 ):
     """Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
     Because FX IR is a higher level IR, should have much fewer
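The switch away from a literal {} default matters because Python evaluates a mutable default once, at definition time, so every call would share (and could silently pollute) the same dict. Below is a minimal sketch of the None-sentinel idiom using a hypothetical stand-in function, not the real dfs_gm body:

    from typing import Dict, Optional

    def dfs_gm_sketch(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
        # Normalize the sentinel inside the call so each invocation gets its own dict.
        if lut_fx_mod_name_to_org is None:
            lut_fx_mod_name_to_org = {}
        # ... graph traversal would go here; the LUT maps FX module names back to
        # the names used in the original (pre-trace) model.
        return lut_fx_mod_name_to_org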
@@ -227,7 +228,9 @@ def _dfs(curr_node, depth):
         return node_found
 
 
-def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, lut_fx_mod_name_to_org={}):
+def find_conv_on_shortcut_gm(
+    gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
+):
     """Identify Conv on shortcut using FX GM DFS
     It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
     original model, not FX module names)
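A rough usage sketch, not part of this patch, assuming a standard torchvision ResNet traced with FX; for resnet18 one would expect back the 1x1 downsample convs on the residual shortcuts (e.g. something like layer2.0.downsample.0), though the exact names depend on the model:

    import torchvision
    from torch.fx import symbolic_trace

    # Trace a ResNet-like CNN to an FX GraphModule, then look for convs on the shortcut path.
    model = torchvision.models.resnet18()
    gm = symbolic_trace(model)
    shortcut_convs = find_conv_on_shortcut_gm(gm)
    print(shortcut_convs)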
@@ -352,7 +355,7 @@ def find_1st_last_gm(
     firstOps=None,
     lastOps=None,
     return_1st_last_sep=False,
-    lut_fx_mod_name_to_org={},
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
 ):
     """Identify the first and last layer of interests
     Usually only interested in Conv and Linear, but could be others as well
@@ -403,7 +406,11 @@ def find_1st_last_gm(
 
 
 def find_single_sided_op_gm(
-    gm, op_of_interest=None, return_LUTs=False, verbose=False, lut_fx_mod_name_to_org={}
+    gm,
+    op_of_interest=None,
+    return_LUTs=False,
+    verbose=False,
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
 ):
     """Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
     itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
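To make the polarity idea concrete, here is an independent illustration (not the module's own code): ops whose outputs are provably non-negative, such as ReLU or Sigmoid, make the Conv/Linear that consumes them a candidate for a single-sided (unsigned) activation quantizer.

    import torch

    NON_NEG_FUNCS = {torch.relu, torch.nn.functional.relu, torch.sigmoid}
    NON_NEG_MODS = (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Sigmoid)

    def single_sided_candidates_sketch(gm: torch.fx.GraphModule):
        mods = dict(gm.named_modules())
        non_neg = set()  # nodes whose outputs are known to be >= 0
        candidates = []
        for node in gm.graph.nodes:  # nodes are already in topological order
            if node.op == "call_function" and node.target in NON_NEG_FUNCS:
                non_neg.add(node)
            elif node.op == "call_module" and isinstance(mods[node.target], NON_NEG_MODS):
                non_neg.add(node)
            elif node.op == "call_module" and isinstance(
                mods[node.target], (torch.nn.Conv2d, torch.nn.Linear)
            ):
                if node.args and node.args[0] in non_neg:
                    candidates.append(node.target)
        return candidates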
@@ -576,7 +583,9 @@ def find_single_sided_op_gm(
     return SingleSidedOps
 
 
-def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}):
+def find_qkvsync_candidates_gm(
+    gm, return_nodes=False, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
+):
     """Identify groups of Linears that share the same parent. It's a transformer-specific feature.
 
     NOTE:
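As a standalone illustration of "Linears sharing the same parent" (again not this module's implementation): group call_module Linear nodes by the node that feeds them; in a transformer block the Q/K/V projections typically all consume the same LayerNorm output, so they land in one group whose quantizers can then be kept in sync.

    from collections import defaultdict

    import torch

    def qkv_groups_sketch(gm: torch.fx.GraphModule):
        mods = dict(gm.named_modules())
        groups = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "call_module" and isinstance(mods[node.target], torch.nn.Linear):
                if node.args:
                    groups[node.args[0]].append(node.target)
        # parents feeding two or more Linears are the potential Q/K/V sibling groups
        return {parent: names for parent, names in groups.items() if len(names) > 1}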
@@ -626,7 +635,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}
         return my_1st_sibling
 
 
-def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
+def find_silu_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
     """Special handle for Conv following silu, specific for EffDet and etc
     LLM could use SiLU as well (llama?), but not relevant to this func
     """
@@ -646,7 +655,12 @@ def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
     return siluConv
 
 
-def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
+def find_rpn_fpn_gm(
+    gm,
+    verbose=False,
+    Nsubgraph=0,
+    lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
+):
     """For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
     are commonly used. Prefer to skip them, but may be ok to quantize in some cases.
 
@@ -750,7 +764,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
     return fpn_convs
 
 
-def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org={}):
+def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
     """Previously with TorchScript, we use this func to perform 2 tasks:
     a) create QBmms, and then attach them to the model,
     b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
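The context-manager half of task b) can be pictured roughly as below; this is only a sketch of the patching idea, and the real patch_torch_bmm()/QBmm in this repo (driven by qcfg["which2patch_contextmanager"]) may differ in detail.

    import contextlib

    import torch

    @contextlib.contextmanager
    def patch_bmm_sketch(qbmm):
        # Temporarily route torch.bmm through a (quantized) wrapper, then restore it.
        orig_bmm = torch.bmm

        def wrapped(a, b):
            return qbmm(a, b)  # e.g. a QBmm module that quantizes both operands

        torch.bmm = wrapped
        try:
            yield
        finally:
            torch.bmm = orig_bmm

    # usage: with patch_bmm_sketch(my_qbmm): out = model(x)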
@@ -1153,7 +1167,9 @@ def cus_backend_model_analyzer(
             )
             setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
 
-    # c) identify RPN/FPN TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
+    # c) identify RPN/FPN
+    # TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
+
     # Third Party
     from torchvision.models.detection.rpn import RegionProposalNetwork
     from torchvision.ops import FeaturePyramidNetwork
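The isinstance-based hack these imports support can be sketched as follows (hedged, since the actual loop in cus_backend_model_analyzer is not shown in this hunk): walk named_modules() and collect everything living under a torchvision RPN or FPN so those layers can be skipped during quantization.

    import torch
    from torchvision.models.detection.rpn import RegionProposalNetwork
    from torchvision.ops import FeaturePyramidNetwork

    def rpn_fpn_module_names_sketch(model: torch.nn.Module):
        skip = []
        for name, mod in model.named_modules():
            if isinstance(mod, (RegionProposalNetwork, FeaturePyramidNetwork)):
                # every submodule under this subtree inherits the parent's name prefix
                skip.extend(f"{name}.{sub}" for sub, _ in mod.named_modules() if sub)
        return skip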