Skip to content

Commit dfddea4

Browse files
committed
Update
1 parent d698d66 commit dfddea4

File tree

2 files changed

+50
-133
lines changed

2 files changed

+50
-133
lines changed

graph_net/analysis_util.py

Lines changed: 3 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,11 @@
11
import os
2-
import json
32
import re
43
import numpy as np
54
from scipy.stats import gmean
65
from collections import OrderedDict, defaultdict
76
from graph_net.config.datatype_tolerance_config import get_precision
87

98

10-
def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict:
11-
"""
12-
Reads speedup data from JSON files within each immediate subdirectory of the benchmark_path.
13-
Each subdirectory is treated as a separate category.
14-
Returns a dictionary mapping {subdir_name: [speedup_values]}.
15-
"""
16-
data_by_subdir = defaultdict(list)
17-
18-
if not os.path.exists(benchmark_path):
19-
print(f"Error: Path does not exist -> {benchmark_path}")
20-
return {}
21-
22-
try:
23-
subdirs = [
24-
d
25-
for d in os.listdir(benchmark_path)
26-
if os.path.isdir(os.path.join(benchmark_path, d))
27-
]
28-
except FileNotFoundError:
29-
print(f"Error: Benchmark path not found -> {benchmark_path}")
30-
return {}
31-
32-
if not subdirs:
33-
print(f"Warning: No subdirectories found in -> {benchmark_path}")
34-
return {}
35-
36-
print(f"Found subdirectories to process: {', '.join(subdirs)}")
37-
38-
for subdir_name in subdirs:
39-
current_dir_path = os.path.join(benchmark_path, subdir_name)
40-
# Using scan_all_folders and load_one_folder could be an alternative,
41-
# but os.walk is also robust for nested directories if needed in the future.
42-
for root, _, files in os.walk(current_dir_path):
43-
for file in files:
44-
if not file.endswith(".json"):
45-
continue
46-
47-
json_file = os.path.join(root, file)
48-
try:
49-
with open(json_file, "r") as f:
50-
data = json.load(f)
51-
performance = data.get("performance", {})
52-
if not performance:
53-
continue
54-
55-
speedup_data = performance.get("speedup")
56-
if isinstance(speedup_data, dict):
57-
# Prioritize 'e2e' speedup, fallback to 'gpu'
58-
if "e2e" in speedup_data:
59-
data_by_subdir[subdir_name].append(speedup_data["e2e"])
60-
elif "gpu" in speedup_data:
61-
data_by_subdir[subdir_name].append(speedup_data["gpu"])
62-
elif isinstance(speedup_data, (float, int)):
63-
data_by_subdir[subdir_name].append(speedup_data)
64-
65-
except (json.JSONDecodeError, KeyError) as e:
66-
print(
67-
f"Warning: Failed to read or parse file -> {json_file}, Error: {e}"
68-
)
69-
continue
70-
71-
return data_by_subdir
72-
73-
74-
def load_json_file(filepath: str) -> dict:
75-
"""
76-
Safely load a JSON file and return data, return an empty dictionary if loading fails.
77-
"""
78-
try:
79-
with open(filepath, "r", encoding="utf-8") as f:
80-
return json.load(f)
81-
except (json.JSONDecodeError, KeyError) as e:
82-
print(f" Warning: Could not process file {filepath}. Error: {e}")
83-
return {}
84-
85-
869
def detect_sample_error_code(log_text: str) -> str:
8710
"""
8811
Detect the error code for a single sample from log text.
@@ -154,8 +77,8 @@ def parse_logs_to_data(log_file: str) -> list:
15477
Parse a structured log file generated by the benchmark script and
15578
return a list of data dictionaries (one per model-compiler run).
15679
157-
This function directly parses log files without generating intermediate JSON files.
158-
It automatically handles both Paddle (with subgraph) and PyTorch (without subgraph) samples.
80+
This function directly parses log files,
81+
handling both Paddle (with subgraph) and PyTorch (without subgraph) samples.
15982
16083
Args:
16184
log_file: Path to the benchmark log file
@@ -229,8 +152,7 @@ def parse_logs_to_data(log_file: str) -> list:
229152
performance_match = patterns["performance"].search(line)
230153
if performance_match:
231154
key, value_str = performance_match.groups()
232-
# The performance value is a JSON string, so we load it
233-
data["performance"][key.strip()] = json.loads(value_str)
155+
data["performance"][key.strip()] = value_str.strip()
234156
continue
235157

236158
datatype_match = patterns["datatype"].search(line)
@@ -409,7 +331,6 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
409331
if atol == 0 and rtol == 0:
410332
metric_key_to_check = "[equal]"
411333
else:
412-
# Use .2E format to ensure two decimal places and use uppercase E to match JSON log format
413334
metric_key_to_check = f"[all_close_atol_{atol:.2E}_rtol_{rtol:.2E}]"
414335

415336
result = correctness_data.get(metric_key_to_check)

graph_net/plot_ESt.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict
162162
return verified_es_values
163163

164164

165-
def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
165+
def plot_ES_results(s_scores: dict, args: argparse.Namespace):
166166
"""
167167
Plot ES(t) curve
168168
"""
@@ -179,8 +179,7 @@ def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
179179
for (
180180
t_key,
181181
score_data,
182-
) in scores_dict.items(): # Change variable name to score_data
183-
# Access the 'score' key from the nested dictionary
182+
) in scores_dict.items():
184183
if isinstance(score_data, dict):
185184
score = score_data["score"]
186185
else:
@@ -234,8 +233,8 @@ def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
234233
markersize=6,
235234
)
236235

237-
p = cli_args.negative_speedup_penalty
238-
config = f"p = {p}, b = {cli_args.fpdb}"
236+
p = args.negative_speedup_penalty
237+
config = f"p = {p}, b = {args.fpdb}"
239238
fig.text(0.5, 0.9, config, ha="center", fontsize=16, style="italic")
240239

241240
ax.set_xlabel("t", fontsize=18)
@@ -253,51 +252,7 @@ def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
253252
return fig, ax, all_x_coords
254253

255254

256-
def main():
257-
"""Main execution function for plotting ES(t)."""
258-
parser = argparse.ArgumentParser(
259-
description="Calculate and plot ES(t) scores from benchmark results.",
260-
formatter_class=argparse.RawTextHelpFormatter,
261-
)
262-
# Add arguments (same as plot_St)
263-
parser.add_argument(
264-
"--benchmark-path",
265-
type=str,
266-
required=True,
267-
help="Path to the benchmark log file or directory containing benchmark JSON files or sub-folders.",
268-
)
269-
parser.add_argument(
270-
"--output-dir",
271-
type=str,
272-
default="analysis_results",
273-
help="Output directory for saving the plot. Default: analysis_results",
274-
)
275-
parser.add_argument(
276-
"--negative-speedup-penalty",
277-
type=float,
278-
default=0.0,
279-
help="Penalty power (p) for negative speedup. Formula: speedup**(p+1). Default: 0.0.",
280-
)
281-
parser.add_argument(
282-
"--fpdb",
283-
type=float,
284-
default=0.1,
285-
help="Base penalty for severe errors (e.g., crashes, correctness failures).",
286-
)
287-
parser.add_argument(
288-
"--enable-aggregation-mode",
289-
action="store_true",
290-
help="Enable aggregation mode to verify aggregated/microscopic consistency. Default: enabled.",
291-
)
292-
parser.add_argument(
293-
"--disable-aggregation-mode",
294-
dest="enable_aggregation_mode",
295-
action="store_false",
296-
help="Disable aggregation mode verification.",
297-
)
298-
parser.set_defaults(enable_aggregation_mode=True)
299-
args = parser.parse_args()
300-
255+
def main(args):
301256
# 1. Scan folders to get data
302257
all_results = analysis_util.scan_all_folders(args.benchmark_path)
303258
if not all_results:
@@ -433,4 +388,45 @@ def main():
433388

434389

435390
if __name__ == "__main__":
436-
main()
391+
parser = argparse.ArgumentParser(
392+
description="Calculate and plot ES(t) scores from benchmark results.",
393+
formatter_class=argparse.RawTextHelpFormatter,
394+
)
395+
parser.add_argument(
396+
"--benchmark-path",
397+
type=str,
398+
required=True,
399+
help="Path to the benchmark log file or directory containing benchmark JSON files or sub-folders.",
400+
)
401+
parser.add_argument(
402+
"--output-dir",
403+
type=str,
404+
default="analysis_results",
405+
help="Output directory for saving the plot. Default: analysis_results",
406+
)
407+
parser.add_argument(
408+
"--negative-speedup-penalty",
409+
type=float,
410+
default=0.0,
411+
help="Penalty power (p) for negative speedup. Formula: speedup**(p+1). Default: 0.0.",
412+
)
413+
parser.add_argument(
414+
"--fpdb",
415+
type=float,
416+
default=0.1,
417+
help="Base penalty for severe errors (e.g., crashes, correctness failures).",
418+
)
419+
parser.add_argument(
420+
"--enable-aggregation-mode",
421+
action="store_true",
422+
help="Enable aggregation mode to verify aggregated/microscopic consistency. Default: enabled.",
423+
)
424+
parser.add_argument(
425+
"--disable-aggregation-mode",
426+
dest="enable_aggregation_mode",
427+
action="store_false",
428+
help="Disable aggregation mode verification.",
429+
)
430+
parser.set_defaults(enable_aggregation_mode=True)
431+
args = parser.parse_args()
432+
main(args)

0 commit comments

Comments
 (0)