Skip to content

Commit 56492c3

Browse files
authored
fix quantiles and merge_csv issues in bench (#221)
1 parent e527312 commit 56492c3

File tree

4 files changed

+139
-6
lines changed

4 files changed

+139
-6
lines changed

exps/attn/run_benchmark.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,10 @@ def fn():
719719
def ms_to_tflops(ms: float) -> float:
720720
return attn_flops / ms * 1e-9
721721

722-
perf_dict["flops"] = list(map(ms_to_tflops, perf_dict["flops"]))
722+
flops = perf_dict["flops"]
723+
if not isinstance(flops, list):
724+
flops = [flops] # type: ignore[unreachable]
725+
perf_dict["flops"] = list(map(ms_to_tflops, flops))
723726

724727
# disable mem test
725728
# def gb(m):

exps/dist_attn/merge_draw_csv.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2025-2026 SandAI. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pathlib import Path
16+
from typing import Dict, List
17+
18+
import pandas as pd
19+
from baselines.interface import AttnImpl
20+
21+
from magi_attention.benchmarking import Mark
22+
23+
24+
def append_files_with_value_prefix(
25+
files_pack: List[str],
26+
expected_cols: List[str],
27+
prefix_list: List[str],
28+
output_path: str,
29+
short_for_xlables: Dict[str, str] | None = None,
30+
):
31+
"""
32+
Append rows from multiple CSV files in order.
33+
Add prefix to the VALUES of expected_cols[0] for each file.
34+
Optionally replace values using a mapping dictionary.
35+
36+
Args:
37+
files_pack (list[str]): Ordered list of CSV files.
38+
expected_cols (list[str]): Columns to extract (same for all files).
39+
Prefix is applied to expected_cols[0].
40+
prefix_list (list[str]): Prefix per file (same length as files_pack).
41+
output_path (str): Output CSV path.
42+
short_for_xlables (dict[str,str] | None): Optional mapping for replacing values
43+
after prefixing.
44+
"""
45+
46+
if len(files_pack) != len(prefix_list):
47+
raise ValueError("files_pack and prefix_list must have the same length")
48+
49+
if not expected_cols:
50+
raise ValueError("expected_cols must not be empty")
51+
52+
key_col = expected_cols[0]
53+
dfs = []
54+
55+
for file_path, prefix in zip(files_pack, prefix_list):
56+
file_path = Path(file_path) # type: ignore[assignment]
57+
if not file_path.exists(): # type: ignore[attr-defined]
58+
raise FileNotFoundError(file_path)
59+
60+
df = pd.read_csv(file_path)
61+
62+
missing = set(expected_cols) - set(df.columns)
63+
if missing:
64+
raise KeyError(f"{file_path} missing columns: {missing}")
65+
66+
df = df[expected_cols].copy()
67+
68+
def add_prefix_and_replace(v: str) -> str:
69+
v_prefixed = f"{prefix}_{v}"
70+
if short_for_xlables and v in short_for_xlables:
71+
return f"{prefix}_{short_for_xlables[v]}"
72+
return v_prefixed
73+
74+
df[key_col] = df[key_col].astype(str).apply(add_prefix_and_replace)
75+
76+
dfs.append(df)
77+
78+
final_df = pd.concat(dfs, ignore_index=True)
79+
final_df.to_csv(output_path, index=False)
80+
81+
print(f"Saved appended CSV to: {output_path}")
82+
83+
84+
if __name__ == "__main__":
    # List of CSV files to be merged into a single file
    src_files_pack = [
        "./outputs_dcp_1_8/fwd/output-8-full-fwd.csv",
        "./outs/output-8-full-fwd.csv",
    ]
    # Columns to extract from CSV files during the merge
    expected_cols = ["baseline", "tflops-mean"]
    # Prefixes to prepend to the values of the first expected column for each file,
    # in order to distinguish between baselines.
    prefix_list = ["old", "test"]
    # Path to save the resulting merged CSV file
    output_path = "./merged.csv"

    # Short names substituted for the baseline names found in the CSV files
    short_for_xlables = {
        AttnImpl.ULYSSES.value: "a2a",
        AttnImpl.RING_P2P.value: "p2p",
        AttnImpl.RING_ALLGATHER.value: "ag",
        AttnImpl.USP.value: "usp",
        AttnImpl.LOONGTRAIN.value: "loongt",
        AttnImpl.MAGI_ATTENTION.value: "magi",
        AttnImpl.HYBRID_DCP.value: "dcp",
    }
    append_files_with_value_prefix(
        src_files_pack,
        expected_cols,
        prefix_list,
        output_path,
        short_for_xlables,
    )

    # Plot the merged CSV that was just written
    Mark.draw_from_csv(
        csv_path=output_path,
        perf_key="test",
        plot_name="merged-dist-attn-exps",
        line_arg="flops",
        save_path="./",
        ylabel="Throughput (TFLOPs/s)",
        x_int=False,
        x_log=False,
    )

exps/dist_attn/run_benchmark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,8 @@ def run_benchmark(
871871
cp_pg_meta=cp_pg_meta,
872872
)
873873

874-
output_n = len(BENCH_CONFIG.quantiles) if BENCH_CONFIG.quantiles else 1
874+
quantiles = getattr(BENCH_CONFIG, "quantiles", None)
875+
output_n = len(quantiles) if quantiles else 1
875876
perf_dict_total = {
876877
"flops": [0] * output_n,
877878
"mem": [0] * output_n,
@@ -967,7 +968,7 @@ def run_benchmark(
967968
)
968969
perf_dict = do_bench(
969970
fn,
970-
quantiles=BENCH_CONFIG.quantiles,
971+
quantiles=quantiles,
971972
mem_record_mode="peak",
972973
return_mode=BENCH_CONFIG.bench_mode,
973974
return_flops=BENCH_CONFIG.bench_flops,
@@ -991,6 +992,8 @@ def mem_to_gb(mem: int) -> float:
991992

992993
if BENCH_CONFIG.bench_flops and not is_profile:
993994
flops = perf_dict["flops"]
995+
if not isinstance(flops, list):
996+
flops = [flops] # type: ignore[unreachable]
994997
flops = torch.tensor(
995998
flops, dtype=torch.float32, device=torch.cuda.current_device()
996999
)
@@ -1049,7 +1052,7 @@ def mem_to_gb(mem: int) -> float:
10491052
perf_dict_total,
10501053
result_info,
10511054
BENCH_CONFIG.bench_mode,
1052-
BENCH_CONFIG.quantiles,
1055+
quantiles,
10531056
BENCH_CONFIG.bench_flops,
10541057
BENCH_CONFIG.bench_mem,
10551058
)

magi_attention/benchmarking/bench.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,11 @@ def _call(self, bench: Benchmark, **kwargs):
370370
try:
    y_mean, _, _ = v
    # y_mean, y_min, y_max = v
except TypeError:
    # `v` is not iterable at all (e.g. a plain scalar): unpacking a
    # non-iterable raises TypeError, not ValueError, so use it as-is.
    y_mean = v
except ValueError:
    # `v` is iterable but does not hold exactly three items: take its
    # first entry, falling back to `v` itself if it cannot be indexed.
    try:
        y_mean = v[0]
    except TypeError:
        y_mean = v
375378
# y_mean, y_min, y_max = v, None, None # type: ignore
376379
row_mean.setdefault(k, []).append(y_mean)
377380
# row_min.setdefault(k, []).append(y_min)

0 commit comments

Comments
 (0)