Skip to content

Commit c7cd321

Browse files
Fix type
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 92ffd63 commit c7cd321

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def dfs_gm(
4747
prescreenOp=None,
4848
hook=None,
4949
return_nodes=False,
50-
lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None,
50+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
5151
):
5252
"""Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5353
Because FX IR is a higher level IR, should have much fewer
@@ -229,7 +229,7 @@ def _dfs(curr_node, depth):
229229

230230

231231
def find_conv_on_shortcut_gm(
232-
gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None
232+
gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
233233
):
234234
"""Identify Conv on shortcut using FX GM DFS
235235
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -355,7 +355,7 @@ def find_1st_last_gm(
355355
firstOps=None,
356356
lastOps=None,
357357
return_1st_last_sep=False,
358-
lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None,
358+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
359359
):
360360
"""Identify the first and last layer of interests
361361
Usually only interested in Conv and Linear, but could be others as well
@@ -410,7 +410,7 @@ def find_single_sided_op_gm(
410410
op_of_interest=None,
411411
return_LUTs=False,
412412
verbose=False,
413-
lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None,
413+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
414414
):
415415
"""Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
416416
itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -584,7 +584,7 @@ def find_single_sided_op_gm(
584584

585585

586586
def find_qkvsync_candidates_gm(
587-
gm, return_nodes=False, lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None
587+
gm, return_nodes=False, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
588588
):
589589
"""Identify groups of Linears that share the same parent. It's a transformer-specific feature.
590590
@@ -635,7 +635,7 @@ def find_qkvsync_candidates_gm(
635635
return my_1st_sibling
636636

637637

638-
def find_silu_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None):
638+
def find_silu_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
639639
"""Special handle for Conv following silu, specific for EffDet and etc
640640
LLM could use SiLU as well (llama?), but not relavent to this func
641641
"""
@@ -659,7 +659,7 @@ def find_rpn_fpn_gm(
659659
gm,
660660
verbose=False,
661661
Nsubgraph=0,
662-
lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None,
662+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
663663
):
664664
"""For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
665665
are commonly used. prefer to skip them, but may be ok to quantize in some cases.
@@ -764,7 +764,7 @@ def find_rpn_fpn_gm(
764764
return fpn_convs
765765

766766

767-
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None):
767+
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
768768
"""Previously with TorchScript, we use this func to perform 2 tasks:
769769
a) create QBmms, and then attach them to the model,
770770
b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context

fms_mo/fx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def lname_to_org_name(Lname):
318318

319319

320320
def get_org_mod_name_of_fx_node(
321-
node, gm=None, lut_fx2org: Optional[Dict[int, str]] = None
321+
node, gm=None, lut_fx2org: Optional[Dict[str, str]] = None
322322
):
323323
"""Given a FX node, could be call_module or call_fuction, find out the original module name,
324324
based on meta data
@@ -491,7 +491,7 @@ def plot_graph_module(
491491
skip_nodes=None,
492492
Nnode_to_plot=None,
493493
additional_coloring_rules=None,
494-
lut_fx_mod_name_to_org: Optional[Dict[int, str]] = None,
494+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
495495
):
496496
"""Plots a GraphModule in .SVG format to visualize the compute graph. If graphviz/pygraphviz is
497497
not installed properly, this function will just print out a message and do nothing.

0 commit comments

Comments
 (0)