Skip to content

Commit de959e7

Browse files
Merge pull request #54 from tharapalanivel/fx_lint
Lint for fx
2 parents 2d1d91d + c7cd321 commit de959e7

File tree

4 files changed

+47
-13
lines changed

4 files changed

+47
-13
lines changed

fms_mo/fx/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Imports from fx
16+
"""

fms_mo/fx/dynamo_utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
# Standard
20+
from typing import Dict, Optional
2021
import logging
2122

2223
# Third Party
@@ -46,7 +47,7 @@ def dfs_gm(
4647
prescreenOp=None,
4748
hook=None,
4849
return_nodes=False,
49-
lut_fx_mod_name_to_org={},
50+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
5051
):
5152
"""Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5253
Because FX IR is a higher level IR, should have much fewer
@@ -227,7 +228,9 @@ def _dfs(curr_node, depth):
227228
return node_found
228229

229230

230-
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, lut_fx_mod_name_to_org={}):
231+
def find_conv_on_shortcut_gm(
232+
gm: torch.fx.GraphModule, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
233+
):
231234
"""Identify Conv on shortcut using FX GM DFS
232235
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
233236
original model, not FX module names)
@@ -352,7 +355,7 @@ def find_1st_last_gm(
352355
firstOps=None,
353356
lastOps=None,
354357
return_1st_last_sep=False,
355-
lut_fx_mod_name_to_org={},
358+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
356359
):
357360
"""Identify the first and last layer of interests
358361
Usually only interested in Conv and Linear, but could be others as well
@@ -403,7 +406,11 @@ def find_1st_last_gm(
403406

404407

405408
def find_single_sided_op_gm(
406-
gm, op_of_interest=None, return_LUTs=False, verbose=False, lut_fx_mod_name_to_org={}
409+
gm,
410+
op_of_interest=None,
411+
return_LUTs=False,
412+
verbose=False,
413+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
407414
):
408415
"""Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
409416
itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -576,7 +583,9 @@ def find_single_sided_op_gm(
576583
return SingleSidedOps
577584

578585

579-
def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}):
586+
def find_qkvsync_candidates_gm(
587+
gm, return_nodes=False, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None
588+
):
580589
"""Identify groups of Linears that share the same parent. It's a transformer-specific feature.
581590
582591
NOTE:
@@ -626,7 +635,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}
626635
return my_1st_sibling
627636

628637

629-
def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
638+
def find_silu_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
630639
"""Special handle for Conv following silu, specific for EffDet and etc
631640
LLM could use SiLU as well (llama?), but not relavent to this func
632641
"""
@@ -646,7 +655,12 @@ def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
646655
return siluConv
647656

648657

649-
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
658+
def find_rpn_fpn_gm(
659+
gm,
660+
verbose=False,
661+
Nsubgraph=0,
662+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
663+
):
650664
"""For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
651665
are commonly used. prefer to skip them, but may be ok to quantize in some cases.
652666
@@ -750,7 +764,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
750764
return fpn_convs
751765

752766

753-
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org={}):
767+
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None):
754768
"""Previously with TorchScript, we use this func to perform 2 tasks:
755769
a) create QBmms, and then attach them to the model,
756770
b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -1153,7 +1167,9 @@ def cus_backend_model_analyzer(
11531167
)
11541168
setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
11551169

1156-
# c) identify RPN/FPN TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
1170+
# c) identify RPN/FPN
1171+
# TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
1172+
11571173
# Third Party
11581174
from torchvision.models.detection.rpn import RegionProposalNetwork
11591175
from torchvision.ops import FeaturePyramidNetwork

fms_mo/fx/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Utils for FX graph parsing and external kernel lowering"""
1616

1717
# Standard
18-
from typing import Any
18+
from typing import Any, Dict, Optional
1919
import logging
2020
import operator
2121
import os
@@ -317,7 +317,9 @@ def lname_to_org_name(Lname):
317317
return org_mod_name
318318

319319

320-
def get_org_mod_name_of_fx_node(node, gm=None, lut_fx2org={}):
320+
def get_org_mod_name_of_fx_node(
321+
node, gm=None, lut_fx2org: Optional[Dict[str, str]] = None
322+
):
321323
"""Given a FX node, could be call_module or call_fuction, find out the original module name,
322324
based on meta data
323325
@@ -489,7 +491,7 @@ def plot_graph_module(
489491
skip_nodes=None,
490492
Nnode_to_plot=None,
491493
additional_coloring_rules=None,
492-
lut_fx_mod_name_to_org={},
494+
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
493495
):
494496
"""Plots a GraphModule in .SVG format to visualize the compute graph. If graphviz/pygraphviz is
495497
not installed properly, this function will just print out a message and do nothing.

fms_mo/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""
15-
imports all from prep.py
15+
Imports for utils
1616
"""

0 commit comments

Comments
 (0)