Skip to content

Commit 6c61d60

Browse files
committed
fix codestyle
1 parent a243f3e commit 6c61d60

File tree

8 files changed

+41
-54
lines changed

8 files changed

+41
-54
lines changed

graph_net/analysis_util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import os
22
import re
3-
import numpy as np
3+
import sys
44
from scipy.stats import gmean
5-
from collections import OrderedDict, defaultdict
65
from graph_net.config.datatype_tolerance_config import get_precision
76

87

@@ -265,7 +264,7 @@ def scan_all_folders(benchmark_path: str) -> dict:
265264
print(f"Detected log file: '{benchmark_path}'")
266265
samples = parse_logs_to_data(benchmark_path)
267266
if not samples:
268-
print(f" - No valid data found in log file.")
267+
print(" - No valid data found in log file.")
269268
return {}
270269

271270
folder_name = (

graph_net/constraint_util.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ def __call__(self, model_path):
9090

9191
tensor_metas = self._get_tensor_metas(model_path)
9292
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
93-
logging.warning(f"before create_inputs_by_metas")
93+
logging.warning("before create_inputs_by_metas")
9494
inputs = self.get_dimension_generalizer().create_inputs_by_metas(
9595
module=self.get_model(model_path),
9696
tensor_meta_attrs_list=tensor_meta_attrs_list,
9797
)
98-
logging.warning(f"after create_inputs_by_metas")
98+
logging.warning("after create_inputs_by_metas")
9999
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
100100

101101
def data_input_predicator(input_var_name):
@@ -157,23 +157,23 @@ def get_model(self, model_path):
157157

158158
@contextmanager
159159
def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
160-
logging.warning(f"enter _try_dimension_generalization")
160+
logging.warning("enter _try_dimension_generalization")
161161
if self.config["dimension_generalizer_filepath"] is None:
162162
yield model_path, ()
163163
return
164164
model = self.get_model(model_path)
165165
dim_generalizer = self.get_dimension_generalizer()
166166
dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
167-
logging.warning(f"before need_rewrite")
167+
logging.warning("before need_rewrite")
168168
need_rewrite = dim_gen_pass.need_rewrite(inputs)
169-
logging.warning(f"after need_rewrite")
169+
logging.warning("after need_rewrite")
170170
if not need_rewrite:
171171
yield model_path, ()
172172
return
173173

174-
logging.warning(f"before rewrite")
174+
logging.warning("before rewrite")
175175
graph_module = dim_gen_pass.rewrite(inputs)
176-
logging.warning(f"after rewrite")
176+
logging.warning("after rewrite")
177177
with tempfile.TemporaryDirectory() as tmp_dir:
178178
shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
179179
dim_gen_pass.save_graph_module(graph_module, tmp_dir)
@@ -344,17 +344,17 @@ def filter_fn(input_name, input_idx, axis, dim):
344344
(dim, axes) for dim in unqiue_dims[: i + 1] for axes in [dim2axes[dim]]
345345
)
346346
ctx_mgr = dyn_dim_cstr_feasibility_ctx_mgr
347-
logging.warning(f"before dyn_dim_cstr_feasibility_ctx_mgr")
347+
logging.warning("before dyn_dim_cstr_feasibility_ctx_mgr")
348348
with ctx_mgr(dim_axes_pairs) as dyn_dim_cstr_feasibility:
349-
logging.warning(f"enter dyn_dim_cstr_feasibility_ctx_mgr")
349+
logging.warning("enter dyn_dim_cstr_feasibility_ctx_mgr")
350350
tmp_dyn_dim_cstr = copy.deepcopy(cur_dyn_dim_cstr)
351351
tmp_dyn_dim_cstr.update_symbol2example_value(sym2example_value)
352-
logging.warning(f"before dyn_dim_cstr_feasibility")
352+
logging.warning("before dyn_dim_cstr_feasibility")
353353
is_dyn_dim_cstr_feasible = dyn_dim_cstr_feasibility(tmp_dyn_dim_cstr)
354-
logging.warning(f"after dyn_dim_cstr_feasibility")
354+
logging.warning("after dyn_dim_cstr_feasibility")
355355
if not is_dyn_dim_cstr_feasible:
356356
continue
357357
dyn_dim_cstr = cur_dyn_dim_cstr
358358
append_dim_gen_pass_names(dyn_dim_cstr_feasibility.dim_gen_pass_names)
359-
logging.warning(f"leave dyn_dim_cstr_feasibility_ctx_mgr")
359+
logging.warning("leave dyn_dim_cstr_feasibility_ctx_mgr")
360360
return dyn_dim_cstr, total_dim_gen_pass_names

graph_net/dynamic_dim_constraints.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import sys
21
import sympy
32
import importlib.util as imp
43
from dataclasses import dataclass
5-
import copy
64
from typing import Callable
75
from collections import namedtuple
86

@@ -22,7 +20,7 @@ class DynamicDimConstraints:
2220
kRelations = "dynamic_dim_constraint_relations"
2321

2422
# len(input_shapes) equals number of Model.forward arguments
25-
input_shapes: list[(tuple[sympy.Expr | int], "var-name")]
23+
input_shapes: list[(tuple[sympy.Expr | int], str)]
2624
kInputShapes = "dynamic_dim_constraint_input_shapes"
2725

2826
@classmethod
@@ -44,7 +42,6 @@ def symbolize(
4442
]
4543
Returns created symbol.
4644
"""
47-
import logging
4845

4946
InputDim = namedtuple("InputDim", ["input_idx", "axis", "dim"])
5047
input_dims = [

graph_net/model_path_handler.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import traceback
22
import argparse
3-
import importlib.util
43
from graph_net.imp_util import load_module
5-
import inspect
64
import logging
7-
from pathlib import Path
8-
from typing import Type, Any
95
import sys
106
import json
117
import base64
12-
from contextlib import contextmanager
13-
import logging
148

159
logging.basicConfig(
1610
level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"

graph_net/plot_ESt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def print_verification_result(
8888
)
8989
else:
9090
print(
91-
f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
91+
f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff * 100:.4f}%)"
9292
)
9393

9494

@@ -110,9 +110,9 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict
110110
verified_es_values = {}
111111
mismatches = []
112112

113-
print(f"\n{'='*80}")
113+
print(f"\n{'=' * 80}")
114114
print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'")
115-
print(f"{'='*80}")
115+
print(f"{'=' * 80}")
116116

117117
for tolerance, microscopic_es in es_scores.items():
118118
aggregated_es = aggregated_results.get(tolerance)
@@ -136,29 +136,29 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict
136136
elif not is_matched:
137137
mismatches.append(
138138
f"t={tolerance}: Mismatch - Microscopic={microscopic_es:.6f}, "
139-
f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff*100:.4f}%)"
139+
f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff * 100:.4f}%)"
140140
)
141141
else:
142142
verified_es_values[tolerance] = microscopic_es
143143

144144
if mismatches:
145145
error_msg = (
146-
f"\n{'='*80}\n"
146+
f"\n{'=' * 80}\n"
147147
f"ERROR: Aggregated and microscopic results do not match for '{folder_name}'!\n"
148-
f"{'='*80}\n"
148+
f"{'=' * 80}\n"
149149
f"Mismatches:\n"
150150
+ "\n".join(f" - {mismatch}" for mismatch in mismatches)
151151
+ f"\n\nCalculation validation failed. Please verify the calculation logic "
152152
f"using verify_aggregated_params.py\n"
153-
f"{'='*80}\n"
153+
f"{'=' * 80}\n"
154154
)
155155
print(error_msg)
156156
raise AssertionError(error_msg)
157157

158158
print(
159159
f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'."
160160
)
161-
print(f"{'='*80}\n")
161+
print(f"{'=' * 80}\n")
162162
return verified_es_values
163163

164164

graph_net/test_compiler_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_device_utilization(device_id, device_count, synchronizer_func):
5656
cmd = [
5757
"nvidia-smi",
5858
f"--id={selected_gpu_id}",
59-
f"--query-gpu=index,gpu_uuid,utilization.gpu,memory.used,memory.total",
59+
"--query-gpu=index,gpu_uuid,utilization.gpu,memory.used,memory.total",
6060
"--format=csv,noheader,nounits",
6161
]
6262
output = subprocess.check_output(cmd).decode().strip()
@@ -78,7 +78,7 @@ def get_device_utilization(device_id, device_count, synchronizer_func):
7878
cmd = [
7979
"nvidia-smi",
8080
f"--id={selected_gpu_id}",
81-
f"--query-compute-apps=gpu_uuid,pid,used_memory",
81+
"--query-compute-apps=gpu_uuid,pid,used_memory",
8282
"--format=csv,noheader,nounits",
8383
]
8484
output = subprocess.check_output(cmd).decode().strip()
@@ -126,14 +126,14 @@ def get_model_name(model_path):
126126

127127
if model_name is None:
128128
fields = model_path.split(os.sep)
129-
pattern = rf"^subgraph(_\d+)?$"
129+
pattern = r"^subgraph(_\d+)?$"
130130
model_name = fields[-2] if re.match(pattern, fields[-1]) else fields[-1]
131131
return model_name
132132

133133

134134
def get_subgraph_tag(model_path):
135135
fields = model_path.split(os.sep)
136-
pattern = rf"^subgraph(_\d+)?$"
136+
pattern = r"^subgraph(_\d+)?$"
137137
return fields[-1] if re.match(pattern, fields[-1]) else ""
138138

139139

graph_net/verify_aggregated_params.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import os
21
import argparse
3-
import numpy as np
42
from collections import OrderedDict, Counter
53
from graph_net import analysis_util
64
from graph_net import samples_statistics
75
from graph_net.samples_statistics import (
86
get_errno_from_error_type,
9-
get_error_type_from_errno,
107
)
118

129

@@ -298,9 +295,9 @@ def verify_es_constructor_params_across_tolerances(
298295
"""
299296
total_samples = len(samples)
300297

301-
print(f"\n{'='*80}")
298+
print(f"\n{'=' * 80}")
302299
print(f"Verifying Aggregated Parameters for '{folder_name}'")
303-
print(f"{'='*80}")
300+
print(f"{'=' * 80}")
304301

305302
tolerances = determine_tolerances(samples)
306303
builder = ToleranceReportBuilder(
@@ -314,9 +311,9 @@ def verify_es_constructor_params_across_tolerances(
314311
(tolerance, builder.build_report(tolerance)) for tolerance in tolerances
315312
)
316313

317-
print(f"\n{'='*80}")
318-
print(f"Aggregated Parameter Verification Complete")
319-
print(f"{'='*80}\n")
314+
print(f"\n{'=' * 80}")
315+
print("Aggregated Parameter Verification Complete")
316+
print(f"{'=' * 80}\n")
320317

321318
return results
322319

@@ -355,12 +352,12 @@ def main():
355352

356353
# Calculate and print aggregated parameters for each curve
357354
for folder_name, samples in all_results.items():
358-
aggregated_results = verify_es_constructor_params_across_tolerances(
355+
_ = verify_es_constructor_params_across_tolerances(
359356
samples,
360357
folder_name,
361358
negative_speedup_penalty=args.negative_speedup_penalty,
362359
fpdb=args.fpdb,
363-
)
360+
) # noqa: F841
364361

365362

366363
if __name__ == "__main__":

tools/update_model_configs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def process_single_model(model_dir_path: str, root_dir: str, failures: list):
7777
relative_path = os.path.relpath(model_dir_path, root_dir)
7878
potential_model_id = relative_path.replace(os.path.sep, "/")
7979

80-
print(f"\n{'='*20}\nProcessing: {potential_model_id} (at {model_dir_path})")
80+
print(f"\n{'=' * 20}\nProcessing: {potential_model_id} (at {model_dir_path})")
8181

8282
try:
8383
with open(json_path, "r", encoding="utf-8") as f:
@@ -90,9 +90,9 @@ def process_single_model(model_dir_path: str, root_dir: str, failures: list):
9090
print(f" [Phase 1a] Trying precise lookup for '{potential_model_id}'...")
9191
try:
9292
best_match_info = model_info(potential_model_id)
93-
print(f" [Phase 1a] Success! Found exact match from path.")
93+
print(" [Phase 1a] Success! Found exact match from path.")
9494
except RepositoryNotFoundError:
95-
print(f" [Phase 1a] Precise lookup failed.")
95+
print(" [Phase 1a] Precise lookup failed.")
9696
# Phase 1b: If that fails, and it's a single-level dir, try replacing the first '_' with '/'
9797
if "/" not in potential_model_id and "_" in potential_model_id:
9898
hypothetical_id = potential_model_id.replace("_", "/", 1)
@@ -102,10 +102,10 @@ def process_single_model(model_dir_path: str, root_dir: str, failures: list):
102102
try:
103103
best_match_info = model_info(hypothetical_id)
104104
print(
105-
f" [Phase 1b] Success! Found exact match by replacing underscore."
105+
" [Phase 1b] Success! Found exact match by replacing underscore."
106106
)
107107
except RepositoryNotFoundError:
108-
print(f" [Phase 1b] Alternative lookup failed.")
108+
print(" [Phase 1b] Alternative lookup failed.")
109109

110110
# --- Stage 2: Advanced Fuzzy Search (Fallback) ---
111111
if not best_match_info:
@@ -174,7 +174,7 @@ def process_model_directories(root_dir: str):
174174
"""
175175
if not root_dir or not os.path.isdir(root_dir):
176176
print(
177-
f"❌ ERROR: Root directory not provided or not found. Please use the --directory argument."
177+
"❌ ERROR: Root directory not provided or not found. Please use the --directory argument."
178178
)
179179
return
180180

@@ -199,7 +199,7 @@ def process_model_directories(root_dir: str):
199199
for model_path in model_paths_to_process:
200200
process_single_model(model_path, root_dir, failures)
201201

202-
print(f"\n{'='*40}\n🎉 All directories processed!")
202+
print(f"\n{'=' * 40}\n🎉 All directories processed!")
203203

204204
if failures:
205205
log_and_print_failures(failures)

0 commit comments

Comments
 (0)