Skip to content

Commit fc01015

Browse files
committed
[Benchmark]: add --sweep-mode to visualizer and fix numeric x-axis scaling
- Add --sweep-mode argument (token_length|model_config) to benchmarks_visualizer.py for filtering data by sweep type via the x_name column in CSV, defaulting to token_length - Fix x-axis scaling: convert numeric x_values to proper numeric type so matplotlib plots them proportionally instead of equally spaced; string x_values (e.g. model names) remain categorical - Set tick labels only at actual data points for numeric axes - Include sweep_mode suffix in output PNG filenames to avoid overwriting when both sweep types exist for the same kernel - Update README.md with --sweep-mode usage and examples
1 parent 205ee1c commit fc01015

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

benchmark/README.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,42 @@ Follow these steps to benchmark and visualize kernel performance:
1818
3. Visualize results
1919
- Use the visualization script with optional modes:
2020

21-
* To target specific mode(s), pass `--kernel-operation-mode` one or more values.
21+
* `--sweep-mode`: Select which sweep data to plot.
22+
- `token_length` (default): plots where x-axis is sequence length.
23+
- `model_config`: plots where x-axis is model configuration.
24+
* To target specific operation mode(s), pass `--kernel-operation-mode` one or more values.
2225
* If you omit `--kernel-operation-mode`, the script will:
2326
- For `speed` metrics: generate plots for all available modes (forward/backward/full).
2427
- For `memory` metrics: generate only the `full` plot.
2528

2629
Examples:
27-
1. Specific modes (speed):
30+
1. Token-length sweep, specific modes (speed):
2831
```bash
2932
python benchmarks_visualizer.py \
3033
--kernel-name kto_loss \
3134
--metric-name speed \
3235
--kernel-operation-mode forward backward
3336
```
34-
2. All modes (speed):
37+
2. Token-length sweep, all modes (speed):
3538
```bash
3639
python benchmarks_visualizer.py \
3740
--kernel-name kto_loss \
3841
--metric-name speed
3942
```
40-
3. Memory (always full):
43+
3. Model-config sweep (speed):
44+
```bash
45+
python benchmarks_visualizer.py \
46+
--kernel-name geglu \
47+
--metric-name speed \
48+
--sweep-mode model_config
49+
```
50+
4. Memory (always full):
4151
```bash
4252
python benchmarks_visualizer.py \
4353
--kernel-name kto_loss \
4454
--metric-name memory
4555
```
4656

4757
4. View results
48-
- Generated plots will be saved in `benchmark/visualizations/`
58+
- Generated plots will be saved in `benchmark/visualizations/`
59+
- Filenames always end with the sweep mode, which defaults to `token_length` (e.g. `geglu_speed_full_token_length.png`, `geglu_speed_full_model_config.png`)

benchmark/benchmarks_visualizer.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
1414

1515

16+
# x_name value that marks model-config sweep rows in the benchmark CSV data.
17+
# "model_config" sweeps always write x_name="model_config"; token-length
18+
# sweeps use kernel-specific names (e.g. "T"), so we match them by exclusion.
19+
SWEEP_MODE_X_NAME = "model_config"
20+
21+
1622
@dataclass
1723
class VisualizationsConfig:
1824
"""
@@ -22,6 +28,9 @@ class VisualizationsConfig:
2228
kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`)
2329
metric_name (str): Metric name to visualize (speed/memory)
2430
kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
31+
sweep_mode (str, optional): Sweep mode to filter data. "token_length" selects
32+
token/sequence-length sweep data; "model_config" selects model-configuration
33+
sweep data. Defaults to "token_length". When None, all data is considered (legacy behaviour).
2534
extra_config_filter (str, optional): A string to filter extra_benchmark_config.
2635
Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096").
2736
Defaults to None, which means the first available config will be used if multiple exist.
@@ -33,6 +42,7 @@ class VisualizationsConfig:
3342
kernel_name: str
3443
metric_name: str
3544
kernel_operation_mode: str = "full"
45+
sweep_mode: str = "token_length"
3646
extra_config_filter: str | None = None
3747
display: bool = False
3848
overwrite: bool = False
@@ -59,6 +69,15 @@ def parse_args() -> VisualizationsConfig:
5969
default=None,
6070
help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
6171
)
72+
parser.add_argument(
73+
"--sweep-mode",
74+
type=str,
75+
choices=["token_length", "model_config"],
76+
default="token_length",
77+
help="Sweep mode used when running the benchmark. "
78+
"'token_length' selects token/sequence-length sweep data (default); "
79+
"'model_config' selects model-configuration sweep data.",
80+
)
6281
parser.add_argument(
6382
"--extra-config-filter",
6483
type=str,
@@ -93,11 +112,19 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame:
93112
df = pd.read_csv(DATA_PATH)
94113
df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
95114

96-
base_filtered_df = df[
115+
mask = (
97116
(df["kernel_name"] == config.kernel_name)
98117
& (df["metric_name"] == config.metric_name)
99118
& (df["kernel_operation_mode"] == config.kernel_operation_mode)
100-
]
119+
)
120+
121+
# Filter by sweep mode early, before extra_benchmark_config resolution.
122+
if config.sweep_mode == "model_config":
123+
mask = mask & (df["x_name"] == SWEEP_MODE_X_NAME)
124+
elif config.sweep_mode == "token_length":
125+
mask = mask & (df["x_name"] != SWEEP_MODE_X_NAME)
126+
127+
base_filtered_df = df[mask]
101128

102129
if base_filtered_df.empty:
103130
raise ValueError(
@@ -201,6 +228,14 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
201228
if col in df.columns:
202229
df[col] = pd.to_numeric(df[col], errors="coerce")
203230

231+
# Convert x_value to numeric where possible so matplotlib uses a real
232+
# numeric axis (proper proportional spacing). String x_values (e.g.
233+
# model names) stay as-is and will be treated as categorical (evenly spaced).
234+
x_numeric = pd.to_numeric(df["x_value"], errors="coerce")
235+
is_numeric_x = x_numeric.notna().all()
236+
if is_numeric_x:
237+
df["x_value"] = x_numeric
238+
204239
xlabel = df["x_label"].iloc[0]
205240
ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
206241
# Sort by "kernel_provider" to ensure consistent color assignment
@@ -229,12 +264,17 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
229264
errorbar=None,
230265
)
231266

267+
# For numeric x axes, show tick labels only at actual data points
268+
if is_numeric_x:
269+
tick_values = sorted(df["x_value"].unique())
270+
ax.set_xticks(tick_values)
271+
ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values])
272+
232273
# Seaborn can't plot pre-computed error bars, so we need to do it manually
233274
lines = ax.get_lines()
234275
colors = [line.get_color() for line in lines]
235276

236277
for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
237-
# for i, row in group_data.iterrows():
238278
y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
239279
y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
240280
y_error = [y_error_lower, y_error_upper]
@@ -252,9 +292,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
252292
plt.ylabel(ylabel)
253293
plt.tight_layout()
254294

295+
sweep_suffix = f"_{config.sweep_mode}" if config.sweep_mode else ""
255296
out_path = os.path.join(
256297
VISUALIZATIONS_PATH,
257-
f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png",
298+
f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}{sweep_suffix}.png",
258299
)
259300

260301
if config.display:
@@ -288,6 +329,7 @@ def main():
288329
kernel_name=args.kernel_name,
289330
metric_name=args.metric_name,
290331
kernel_operation_mode=mode,
332+
sweep_mode=args.sweep_mode,
291333
extra_config_filter=args.extra_config_filter,
292334
display=args.display,
293335
overwrite=args.overwrite,

0 commit comments

Comments
 (0)