Skip to content

Commit f896b28

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into develop
2 parents 3738057 + bdd6e35 commit f896b28

19 files changed

+541
-87
lines changed

.pre-commit-config.yaml

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,25 @@ repos:
33
rev: 23.1.0
44
hooks:
55
- id: black
6-
language_version: python3
6+
language_version: python3
7+
8+
- repo: https://github.com/astral-sh/ruff-pre-commit
9+
rev: v0.14.4
10+
hooks:
11+
- id: ruff-check
12+
args: [--fix, --exit-non-zero-on-fix, --no-cache]
13+
14+
- repo: https://github.com/PFCCLab/typos-pre-commit-mirror.git
15+
rev: v1.39.2
16+
hooks:
17+
- id: typos
18+
args: [--force-exclude]
19+
20+
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
21+
rev: v1.5.1
22+
hooks:
23+
- id: remove-crlf
24+
- id: remove-tabs
25+
name: Tabs remver (Python)
26+
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
27+
args: [--whitespaces-count, '4']

graph_net/__init__.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
__all__ = ["torch", "paddle"]
2-
3-
from importlib import import_module
4-
from typing import TYPE_CHECKING, Any, List
5-
6-
7-
def __getattr__(name: str) -> Any:
8-
if name in __all__:
9-
module = import_module(f"{__name__}.{name}")
10-
globals()[name] = module
11-
return module
12-
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
13-
14-
15-
def __dir__() -> List[str]:
16-
return sorted(list(globals().keys()) + __all__)
17-
18-
19-
if TYPE_CHECKING:
20-
from . import torch as torch # type: ignore
21-
from . import paddle as paddle # type: ignore
1+
__all__ = ["torch", "paddle"]
2+
3+
from importlib import import_module
4+
from typing import TYPE_CHECKING, Any, List
5+
6+
7+
def __getattr__(name: str) -> Any:
8+
if name in __all__:
9+
module = import_module(f"{__name__}.{name}")
10+
globals()[name] = module
11+
return module
12+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
13+
14+
15+
def __dir__() -> List[str]:
16+
return sorted(list(globals().keys()) + __all__)
17+
18+
19+
if TYPE_CHECKING:
20+
from . import torch as torch # type: ignore
21+
from . import paddle as paddle # type: ignore

graph_net/analysis_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import re
3-
import numpy as np
3+
import sys
44
from scipy.stats import gmean
55
from graph_net.config.datatype_tolerance_config import get_precision
66

@@ -213,7 +213,7 @@ def scan_all_folders(benchmark_path: str) -> dict:
213213
print(f"Detected log file: '{benchmark_path}'")
214214
samples = parse_logs_to_data(benchmark_path)
215215
if not samples:
216-
print(f" - No valid data found in log file.")
216+
print(" - No valid data found in log file.")
217217
return {}
218218

219219
folder_name = (

graph_net/constraint_util.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ def __call__(self, model_path):
9090

9191
tensor_metas = self._get_tensor_metas(model_path)
9292
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
93-
logging.warning(f"before create_inputs_by_metas")
93+
logging.warning("before create_inputs_by_metas")
9494
inputs = self.get_dimension_generalizer().create_inputs_by_metas(
9595
module=self.get_model(model_path),
9696
tensor_meta_attrs_list=tensor_meta_attrs_list,
9797
)
98-
logging.warning(f"after create_inputs_by_metas")
98+
logging.warning("after create_inputs_by_metas")
9999
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
100100

101101
def data_input_predicator(input_var_name):
@@ -157,23 +157,23 @@ def get_model(self, model_path):
157157

158158
@contextmanager
159159
def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
160-
logging.warning(f"enter _try_dimension_generalization")
160+
logging.warning("enter _try_dimension_generalization")
161161
if self.config["dimension_generalizer_filepath"] is None:
162162
yield model_path, ()
163163
return
164164
model = self.get_model(model_path)
165165
dim_generalizer = self.get_dimension_generalizer()
166166
dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
167-
logging.warning(f"before need_rewrite")
167+
logging.warning("before need_rewrite")
168168
need_rewrite = dim_gen_pass.need_rewrite(inputs)
169-
logging.warning(f"after need_rewrite")
169+
logging.warning("after need_rewrite")
170170
if not need_rewrite:
171171
yield model_path, ()
172172
return
173173

174-
logging.warning(f"before rewrite")
174+
logging.warning("before rewrite")
175175
graph_module = dim_gen_pass.rewrite(inputs)
176-
logging.warning(f"after rewrite")
176+
logging.warning("after rewrite")
177177
with tempfile.TemporaryDirectory() as tmp_dir:
178178
shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
179179
dim_gen_pass.save_graph_module(graph_module, tmp_dir)
@@ -293,20 +293,20 @@ def symbolize_data_input_dims(
293293
Returns new DynamicDimConstraints if success.
294294
Returns None if no symbolicable dim .
295295
"""
296-
unqiue_dims = []
296+
unique_dims = []
297297
dim2axes = {}
298298

299299
def dumpy_filter_fn(input_name, input_idx, axis, dim):
300300
if is_data_input(input_name):
301301
print("data_input", input_name, input_idx, axis, dim)
302-
if dim not in unqiue_dims:
303-
unqiue_dims.append(dim)
302+
if dim not in unique_dims:
303+
unique_dims.append(dim)
304304
dim2axes[dim] = []
305305
dim2axes[dim].append(axis)
306306
# No symbolization by returning False
307307
return False
308308

309-
# Collect input dimensions into `unqiue_dims`
309+
# Collect input dimensions into `unique_dims`
310310
assert dyn_dim_cstr.symbolize(dumpy_filter_fn) is None
311311
total_dim_gen_pass_names = ()
312312

@@ -323,7 +323,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
323323
]
324324
)
325325

326-
for i, picked_dim in enumerate(unqiue_dims):
326+
for i, picked_dim in enumerate(unique_dims):
327327
logging.warning(f"{i=} {picked_dim=}")
328328
cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
329329

@@ -341,20 +341,20 @@ def filter_fn(input_name, input_idx, axis, dim):
341341
if not cur_dyn_dim_cstr.check_delta_symbol2example_value(sym2example_value):
342342
continue
343343
dim_axes_pairs = tuple(
344-
(dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
344+
(dim, axes) for dim in unique_dims[: i + 1] for axes in [dim2axes[dim]]
345345
)
346346
ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347-
logging.warning(f"before dyn_dim_cstr_feasibility_ctx_mgr")
347+
logging.warning("before dyn_dim_cstr_feasibility_ctx_mgr")
348348
with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
349-
logging.warning(f"enter dyn_dim_cstr_feasibility_ctx_mgr")
349+
logging.warning("enter dyn_dim_cstr_feasibility_ctx_mgr")
350350
tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
351351
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
352-
logging.warning(f"before dyn_dim_cstr_feasibility")
352+
logging.warning("before dyn_dim_cstr_feasibility")
353353
is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr)
354-
logging.warning(f"after dyn_dim_cstr_feasibility")
354+
logging.warning("after dyn_dim_cstr_feasibility")
355355
if not is_dyn_dim_cstr_feasible:
356356
continue
357357
dyn_dim_cstr = cur_dyn_dim_cstr
358358
append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
359-
logging.warning(f"leave dyn_dim_cstr_feasibility_ctx_mgr")
359+
logging.warning("leave dyn_dim_cstr_feasibility_ctx_mgr")
360360
return dyn_dim_cstr, total_dim_gen_pass_names

graph_net/dynamic_dim_constraints.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import sys
21
import sympy
32
import importlib.util as imp
43
from dataclasses import dataclass
5-
import copy
64
from typing import Callable
75
from collections import namedtuple
86

@@ -22,7 +20,7 @@ class DynamicDimConstraints:
2220
kRelations = "dynamic_dim_constraint_relations"
2321

2422
# len(input_shapes) equals number of Model.forward arguments
25-
input_shapes: list[(tuple[sympy.Expr | int], "var-name")]
23+
input_shapes: list[(tuple[sympy.Expr | int], str)]
2624
kInputShapes = "dynamic_dim_constraint_input_shapes"
2725

2826
@classmethod
@@ -44,7 +42,6 @@ def symbolize(
4442
]
4543
Returns created symbol.
4644
"""
47-
import logging
4845

4946
InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
5047
input_dims = [
@@ -157,7 +154,7 @@ def get_module_attr(cls, module, attr, default):
157154
return getattr(module, attr) if hasattr(module, attr) else default
158155

159156
@classmethod
160-
def load_module(cls, path, name="unamed"):
157+
def load_module(cls, path, name="unnamed"):
161158
spec = imp.spec_from_file_location(name, path)
162159
module = imp.module_from_spec(spec)
163160
spec.loader.exec_module(module)

graph_net/imp_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib.util as imp
22

33

4-
def load_module(path, name="unamed"):
4+
def load_module(path, name="unnamed"):
55
spec = imp.spec_from_file_location(name, path)
66
module = imp.module_from_spec(spec)
77
module.__file__ = path

graph_net/model_path_handler.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import traceback
22
import argparse
3-
import importlib.util
43
from graph_net.imp_util import load_module
5-
import inspect
64
import logging
7-
from pathlib import Path
8-
from typing import Type, Any
95
import sys
106
import json
117
import base64
12-
from contextlib import contextmanager
13-
import logging
148

159
logging.basicConfig(
1610
level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"

graph_net/plot_ESt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def print_verification_result(
8888
)
8989
else:
9090
print(
91-
f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
91+
f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff * 100:.4f}%)"
9292
)
9393

9494

@@ -110,9 +110,9 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict
110110
verified_es_values = {}
111111
mismatches = []
112112

113-
print(f"\n{'='*80}")
113+
print(f"\n{'=' * 80}")
114114
print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'")
115-
print(f"{'='*80}")
115+
print(f"{'=' * 80}")
116116

117117
for tolerance, microscopic_es in es_scores.items():
118118
aggregated_es = aggregated_results.get(tolerance)
@@ -136,29 +136,29 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict
136136
elif not is_matched:
137137
mismatches.append(
138138
f"t={tolerance}: Mismatch - Microscopic={microscopic_es:.6f}, "
139-
f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff*100:.4f}%)"
139+
f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff * 100:.4f}%)"
140140
)
141141
else:
142142
verified_es_values[tolerance] = microscopic_es
143143

144144
if mismatches:
145145
error_msg = (
146-
f"\n{'='*80}\n"
146+
f"\n{'=' * 80}\n"
147147
f"ERROR: Aggregated and microscopic results do not match for '{folder_name}'!\n"
148-
f"{'='*80}\n"
148+
f"{'=' * 80}\n"
149149
f"Mismatches:\n"
150150
+ "\n".join(f" - {mismatch}" for mismatch in mismatches)
151151
+ f"\n\nCalculation validation failed. Please verify the calculation logic "
152152
f"using verify_aggregated_params.py\n"
153-
f"{'='*80}\n"
153+
f"{'=' * 80}\n"
154154
)
155155
print(error_msg)
156156
raise AssertionError(error_msg)
157157

158158
print(
159159
f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'."
160160
)
161-
print(f"{'='*80}\n")
161+
print(f"{'=' * 80}\n")
162162
return verified_es_values
163163

164164

graph_net/status.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def main():
2828
prog="python -m graph_net.status",
2929
description="List contents of the $GRAPH_NET_EXTRACT_WORKSPACE directory (like ls)",
3030
)
31-
args = parser.parse_args()
31+
args = parser.parse_args() # noqa: F841
3232

3333
ws = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
3434
if not ws:

graph_net/tensor_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _convert_cls_to_attrs(cls, tensor_meta_cls):
4848
return attrs
4949

5050
@classmethod
51-
def _get_classes(cls, file_path, name="unamed"):
51+
def _get_classes(cls, file_path, name="unnamed"):
5252
spec = imp.spec_from_file_location("unnamed", file_path)
5353
unnamed = imp.module_from_spec(spec)
5454
spec.loader.exec_module(unnamed)

0 commit comments

Comments
 (0)