Skip to content

Commit dd9a05c

Browse files
committed
fix some hint bugs
1 parent 6a7780b commit dd9a05c

File tree

6 files changed

+7
-22
lines changed

6 files changed

+7
-22
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,11 @@ repos:
1111
- id: ruff-check
1212
args: [--fix, --exit-non-zero-on-fix, --no-cache]
1313

14-
- repo: https://github.com/PFCCLab/typos-pre-commit-mirror.git
15-
rev: v1.39.2
16-
hooks:
17-
- id: typos
18-
args: [--force-exclude]
19-
2014
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
2115
rev: v1.5.1
2216
hooks:
2317
- id: remove-crlf
2418
- id: remove-tabs
2519
name: Tabs remver (Python)
2620
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
27-
args: [--whitespaces-count, '4']
21+
args: [--whitespaces-count, '4']

graph_net/torch/constraint_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def __init__(self, config):
1818
if config is None:
1919
config = {}
2020

21-
graph_net_root = os.path.dirname(graph_net.__file__)
2221
decorator_config = {"use_dummy_inputs": True}
2322
self.predicator = RunModelPredicator(decorator_config)
2423

graph_net/torch/dim_gen_passes/batch_call_method_view_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import torch.fx as fx
32
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
43
import os

graph_net/torch/dim_gen_passes/naive_call_method_view_pass.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import torch.fx as fx
32
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
43
import os
@@ -22,7 +21,7 @@ def _node_need_rewrite(self, node) -> bool:
2221
if not (node.target == "view"):
2322
return False
2423
print(f"{self.dim=} {node.args[1:]=}")
25-
if not (self.dim in node.args[1:]):
24+
if self.dim not in node.args[1:]:
2625
return False
2726
return True
2827

graph_net/torch/dim_gen_passes/tuple_arg_call_method_view_pass.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import torch.fx as fx
32
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
43
import os
@@ -25,7 +24,7 @@ def _node_need_rewrite(self, node) -> bool:
2524
return False
2625
if not (isinstance(node.args[1], tuple)):
2726
return False
28-
if not (self.dim in node.args[1]):
27+
if self.dim not in node.args[1]:
2928
return False
3029
return True
3130

graph_net/torch/static_to_dynamic.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
import traceback
22
import logging
33
import torch
4-
import torch.fx as fx
54
from graph_net.torch.utils import get_dummy_named_tensors
65
from torch.fx.passes.shape_prop import ShapeProp
76
from graph_net.torch.utils import apply_templates
87
from pathlib import Path
98
import inspect
10-
from typing import Any
11-
from contextlib import contextmanager
12-
from torch.export import export
139
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
1410
from graph_net.torch.fx_graph_cache_util import (
1511
parse_immutable_model_path_into_sole_graph_module,
1612
)
17-
from graph_net.imp_util import load_module
1813
import os
1914

2015

@@ -51,12 +46,12 @@ def make_config(self, pass_names=()):
5146
"naive_call_method_reshape_pass",
5247
"naive_call_method_expand_pass",
5348
"non_batch_call_method_expand_pass",
54-
"non_batch_call_function_arange_pass",
49+
"non_batch_call_function_arange_pass", # typos: skip
5550
"non_batch_call_function_getitem_slice_pass",
5651
"non_batch_call_function_full_pass",
5752
"non_batch_call_function_full_plus_one_pass",
5853
"non_batch_call_function_zeros_pass",
59-
"non_batch_call_function_arange_plus_one_pass",
54+
"non_batch_call_function_arange_plus_one_pass", # typos: skip
6055
)
6156
return {
6257
"pass_names": pass_names,
@@ -71,7 +66,7 @@ def need_rewrite(self, inputs):
7166
traced_module = self._create_fx_graph_module(inputs)
7267
logging.warning("after _create_fx_graph_module")
7368
ShapeProp(traced_module).propagate(*inputs)
74-
except Exception as e:
69+
except Exception:
7570
traceback.print_exc()
7671
return False
7772
return any(
@@ -112,4 +107,4 @@ def get_conditional_passes(self):
112107
]
113108

114109
def forward(self, *args, **kwargs):
115-
print(f"Do nothing.")
110+
print("Do nothing.")

0 commit comments

Comments
 (0)