diff --git a/benchmark/BENCHMARK_GUIDELINES.md b/benchmark/BENCHMARK_GUIDELINES.md index 907223430..1c0761e37 100644 --- a/benchmark/BENCHMARK_GUIDELINES.md +++ b/benchmark/BENCHMARK_GUIDELINES.md @@ -1,52 +1,45 @@ # Guideline for Adding Benchmark Scripts -This document describes how to add new benchmark scripts to Liger-Kernel in line with the shared framework. - -## 1. Where and how to add a script +## 1. Where to add a script - **Location**: `benchmark/scripts/` -- **Naming**: `benchmark_<kernel_name>.py` (e.g. `benchmark_geglu.py`, `benchmark_swiglu.py`) +- **Naming**: `benchmark_<kernel_name>.py` (e.g. `benchmark_geglu.py`, `benchmark_dyt.py`) + +> **Baseline implementations**: Import reference (non-Liger) kernels from the +> test suite (e.g. `test/transformers/test_<kernel_name>.py`) to use as baselines. +> This keeps benchmark and test implementations in sync and avoids duplicating +> reference code in benchmark scripts. -## 2. Use shared infrastructure +## 2. Shared infrastructure -Do **not** hardcode batch size, sequence length, or model dimensions. Use: +Do **not** hardcode batch size, sequence length, or model dimensions. All benchmark scripts share the following: | Need | Use | |------|-----| -| Model dimensions (hidden_size, vocab_size, etc.) | `benchmark_model_configs.py`: `ModelConfig`, `get_benchmark_model_config()` | -| Safe sweep config (seq_len or hidden_size) | `compute_seq_len_sweep_config()` (returns `SeqLenSweepConfig`) or `compute_hidden_size_sweep_config()` (returns `HiddenSizeSweepConfig`), with optional `estimate_kernel_peak_memory()` | | Model dimensions (hidden_size, vocab_size, etc.) 
| `benchmark_model_configs.py`: `ModelConfig`, `MODEL_REGISTRY`, `get_benchmark_model_config()` | +| Memory probing | `benchmark_model_configs.py`: `estimate_kernel_peak_memory()` | +| Safe sweep configs | `compute_seq_len_sweep_config()`, `compute_model_config_sweep_config()` | | Speed / memory measurement | `utils.py`: `run_speed_benchmark()`, `run_memory_benchmark()` | -| CLI (overwrite, model choice) | `utils.py`: `parse_benchmark_script_args()` (includes `--model`) | | Running the grid and writing CSV | `utils.py`: `run_benchmarks()` | +| CLI arguments | `utils.py`: `parse_benchmark_script_args()` — provides `--model`, `--overwrite`, `--sweep-mode`, `--bt` | -## 3. Script structure (three parts) - -### 3.1 Setup factory +### 2.1 Setup factory -Define a single **setup function** that builds inputs and the layer (or callable) from `SingleBenchmarkRunInput`, so both speed and memory benchmarks reuse the same setup. +Define a single **setup function** that builds inputs and the layer from `SingleBenchmarkRunInput`, so both speed and memory benchmarks reuse the same setup. - **Signature**: `_setup_<kernel_name>(input: SingleBenchmarkRunInput) -> (tensors, layer_or_fn)` -- **Input**: `input.x` is the varying dimension (e.g. sequence length); `input.extra_benchmark_config` holds `bsz`, `hidden_size`, `dtype`, etc.; `input.kernel_provider` identifies the implementation variant (e.g. `"liger"`, `"huggingface"`, `"torch"`; values are kernel-specific). -- **Return**: Whatever the benchmark helpers need (e.g. `(x, layer)` for a single-tensor forward like GEGLU). - -Example (conceptually): +- **Input**: `input.x` is the varying dimension (e.g. seq_len or hidden_size); `input.extra_benchmark_config` holds fixed params like `bsz`, `hidden_size`, `dtype`; `input.kernel_provider` identifies the implementation variant (`"liger"`, `"huggingface"`, `"torch"`, etc.). 
```python def _setup_geglu(input: SingleBenchmarkRunInput): cfg = input.extra_benchmark_config - # Build config, create x tensor, instantiate LigerGEGLUMLP or LlamaMLP by provider + # Build model config, create x tensor, instantiate layer by provider return x, layer ``` -### 3.2 Speed and memory benchmark functions - -Each takes `SingleBenchmarkRunInput` and returns `SingleBenchmarkRunOutput` by calling the shared helpers. +### 2.2 Speed and memory benchmark functions -- **Speed**: `run_speed_benchmark(fwd_fn, mode, input_tensors, rep=...)` -- **Memory**: `run_memory_benchmark(fwd_fn, mode)` -- **Modes**: Use `["full", "forward", "backward"]` for both speed and memory for consistency. - -Example: +Each takes `SingleBenchmarkRunInput` and returns `SingleBenchmarkRunOutput`: ```python def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -58,44 +51,117 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) ``` -For **scalar output** (e.g. loss) or **multiple outputs** (e.g. RoPE), use the appropriate helpers from `utils.py` if available (e.g. loss or multi-output variants), or implement custom measurement and still use the same setup factory and `run_benchmarks()`. +- Use `kernel_operation_modes=["full", "forward", "backward"]` for both speed and memory. +- For **scalar output** (e.g. loss) or **multiple outputs** (e.g. RoPE), implement custom measurement logic but still use the same setup factory and `run_benchmarks()`. + +### 2.3 Memory probing + +Most scripts should probe peak memory before computing sweep configs: + +1. Define a `_probe()` that creates tensors/layers at a small scale and returns the output tensor. +2. Call `peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)`. +3. Use `peak_bytes` to derive safe sweep parameters (see sections 3 and 4). 
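The probe-to-budget arithmetic behind steps 2 and 3 can be sketched in plain Python. The helper names below (`bytes_per_token`, `max_safe_seq_len`) are illustrative only, not part of `benchmark_model_configs.py`; the real helpers also reset CUDA memory stats and query device memory internally, while here the budget is passed in explicitly:

```python
# Hypothetical sketch of the probe -> sweep-budget arithmetic (plain Python).
# The real framework helpers handle CUDA stat resets and device queries;
# the device budget is a parameter here purely for illustration.

def bytes_per_token(peak_bytes: int, probe_seq_len: int) -> int:
    """Per-token kernel footprint, measured once at a small probe scale."""
    return peak_bytes // probe_seq_len

def max_safe_seq_len(device_free_bytes: int, kernel_bpt: int,
                     batch_size: int, headroom: float = 0.8) -> int:
    """Largest power-of-two seq_len that fits within a memory headroom."""
    budget = int(device_free_bytes * headroom)
    max_tokens = budget // (kernel_bpt * batch_size)
    seq_len = 1
    while seq_len * 2 <= max_tokens:  # round down to a power of two
        seq_len *= 2
    return seq_len

# A probe that peaked at 512 MiB for 1024 tokens costs 512 KiB per token;
# with 80 GiB free, batch size 4, and 80% headroom, the sweep can go up
# to a sequence length of 32768.
kernel_bpt = bytes_per_token(512 * 2**20, 1024)
print(kernel_bpt, max_safe_seq_len(80 * 2**30, kernel_bpt, 4))
```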
-### 3.3 `__main__`: model config, shape computation, run +Use the **highest-memory baseline** implementation for probing (e.g. `"huggingface"` or `"torch"`) to get a safe upper bound. -1. Parse args: `args = parse_benchmark_script_args()` and resolve `model = get_benchmark_model_config(args.model)`. -2. (Recommended) Measure peak memory with a small probe using the **highest-memory baseline** implementation (e.g. `"huggingface"` or `"torch"`): - - Define a `_probe()` function that creates tensors/layers, runs a forward pass, and returns the output tensor. `_probe()` owns setup; `estimate_kernel_peak_memory` handles memory-stat reset before the call, runs `.backward()`, and performs cleanup (gc + cache clear) afterward. - - Call `peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)`. -3. Compute sweep config (device memory is obtained internally by both helpers): - - **Sequence-length sweep** (e.g. GEGLU, SwiGLU): convert peak bytes to per-token (`kernel_bpt = peak_bytes // probe_seq_len`), then `config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)`. The returned `SeqLenSweepConfig` has `batch_size` and `seq_len`. - - **Hidden-size sweep** (e.g. DyT): pass total peak bytes directly: `config = compute_hidden_size_sweep_config(model, kernel_peak_bytes=peak_bytes, bt=BT)`. The returned `HiddenSizeSweepConfig` has `bt` and `max_hidden_size`. -4. Build `x_values` from `config.seq_len` (seq_len sweep) or `config.max_hidden_size` (hidden_size sweep). -5. Build `extra_benchmark_configs` from `model` and config: - - Seq_len sweep: e.g. `bsz=config.batch_size`, `hidden_size=model.hidden_size`, `dtype=model.dtype`. - - Hidden_size sweep: e.g. `BT=config.bt`, `dtype=model.dtype`. -6. Call `run_benchmarks(..., kernel_operation_modes=["full", "forward", "backward"], ...)` for both speed and memory. +## 3. D1 — Non-model dimension sweep + +Sweep non-model dimensions (e.g. sequence length, BT) with a **fixed model config**. 
Use `--model` to select which model. + +### 3.1 How to implement + +In `__main__`, the `token_length` sweep mode (default) follows this pattern: + +1. Parse args and resolve model: `args = parse_benchmark_script_args()`, `model = get_benchmark_model_config(args.model)`. +2. Probe and compute sweep config: + - **seq_len sweep** (GEGLU, SwiGLU, etc.): `kernel_bpt = peak_bytes // probe_seq_len`, then `config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)`. Returns `SeqLenSweepConfig` with `batch_size` and `seq_len`. + - **BT sweep** (other ops): use `BT` directly as a fixed dimension if no sweep is needed. +3. Build `x_values` from `config.seq_len` (e.g. `[2**i for i in range(10, log2(config.seq_len) + 1)]`). +4. Build `extra_benchmark_configs` with fixed model dimensions: `bsz=config.batch_size`, `hidden_size=model.hidden_size`, `dtype=model.dtype`, etc. +5. Call `run_benchmarks(...)` for both speed and memory. + +### 3.2 How to run + +```bash +# Default model (llama_3_8b) +python benchmark_geglu.py + +# Specific model +python benchmark_geglu.py --model llama_2_7b + +# Overwrite existing CSV entries +python benchmark_geglu.py --model llama_3_8b --overwrite +``` -## 4. CLI +### 3.3 Reference scripts -Scripts should support: +- **seq_len sweep**: `benchmark_geglu.py`, `benchmark_swiglu.py` — `compute_seq_len_sweep_config()` -- `--overwrite`: overwrite existing rows in the benchmark CSV. -- `--model`: model profile name from `MODEL_REGISTRY` (e.g. `llama_2_7b`, `llama_3_8b`). Default when not set is `DEFAULT_MODEL_CONFIG` (e.g. `llama_3_8b`). +## 4. D2 — Model dimension sweep -These are provided by `parse_benchmark_script_args()` in `utils.py`. +Sweep across discrete model configs from `MODEL_REGISTRY` with a **fixed token count**. Use `--bt` to set the token count. -## 5. 
Reference scripts -- **Element-wise (single tensor in/out, seq_len sweep)**: `benchmark_geglu.py`, `benchmark_swiglu.py` — `compute_seq_len_sweep_config()`. -- **Element-wise (single tensor in/out, hidden_size sweep)**: `benchmark_dyt.py` — `compute_hidden_size_sweep_config()`. +Sweep across all `MODEL_REGISTRY` entries as discrete data points. Activated by `--sweep-mode model_config`. + +**How to implement:** + +1. Add a `_resolve_model_config_<kernel_name>` helper that maps `input.x` (model index) to a standard `SingleBenchmarkRunInput`: + +```python +def _resolve_model_config_geglu(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][int(input.x)] + return _setup_geglu(SingleBenchmarkRunInput( + x=cfg["seq_len"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "bsz": cfg["bsz"], + "hidden_size": model_info["hidden_size"], + "intermediate_size": model_info["intermediate_size"], + "hidden_act": cfg["hidden_act"], + "dtype": model_info["dtype"], + }, + )) +``` + +2. Add `bench_speed_<kernel_name>_model_config` and `bench_memory_<kernel_name>_model_config`: + +```python +def bench_speed_geglu_model_config(input): + x, layer = _resolve_model_config_geglu(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) +``` + +3. In `__main__`, gate on `args.sweep_mode == "model_config"`: + - Build `_probe_factory(model_cfg, probe_seq_len)` that returns a probe callable. + - Call `sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=..., bt=args.bt)`. + - Build `model_configs_info` (list of dicts with each model's dimensions) and pass in `extra_benchmark_configs`. + - `x_values = list(range(len(sweep.model_configs)))` (model indices). + - Call `run_benchmarks(bench_test_fn=bench_speed_<kernel_name>_model_config, ...)`. 
+ +**Reference**: `benchmark_geglu.py`, `benchmark_swiglu.py`, `benchmark_dyt.py` — all support `--sweep-mode model_config`. + +### 4.2 How to run + +```bash +# Discrete model-config sweep with default bt=2048 +python benchmark_geglu.py --sweep-mode model_config + +# With custom bt +python benchmark_geglu.py --sweep-mode model_config --bt 4096 +``` -## 6. Checklist for a new script +## 5. Checklist - [ ] Script under `benchmark/scripts/` named `benchmark_<kernel_name>.py`. - [ ] Single `_setup_<kernel_name>(SingleBenchmarkRunInput)` used by both speed and memory. -- [ ] Speed/memory implemented via `run_speed_benchmark` / `run_memory_benchmark` (or the correct variant for loss / multi-output). +- [ ] Speed/memory via `run_speed_benchmark` / `run_memory_benchmark` (or custom variant for loss/multi-output). - [ ] `kernel_operation_modes=["full", "forward", "backward"]` for both speed and memory. -- [ ] No hardcoded batch size or sequence length; use `compute_seq_len_sweep_config()` or `compute_hidden_size_sweep_config()` (and optionally `estimate_kernel_peak_memory()`). +- [ ] No hardcoded batch size or sequence length; sweep configs from `compute_*_sweep_config()` + `estimate_kernel_peak_memory()`. - [ ] Model dimensions and dtype from `ModelConfig` / `get_benchmark_model_config()` / `args.model`. -- [ ] CLI via `parse_benchmark_script_args()` (so `--model` and `--overwrite` work). -- [ ] Results written through `run_benchmarks()` so data goes to the shared CSV. +- [ ] CLI via `parse_benchmark_script_args()` (so `--model`, `--overwrite`, `--sweep-mode`, `--bt` all work). +- [ ] Results written through `run_benchmarks()` to the shared CSV. +- [ ] Model-config sweep: `_resolve_model_config_<kernel_name>`, `bench_speed_<kernel_name>_model_config`, `bench_memory_<kernel_name>_model_config`, and `__main__` model-config code path. 
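As a companion to the sweep-config checklist items, the power-of-two `x_values` grid used by the seq_len sweeps can be sketched as follows (an illustrative helper, not part of `utils.py`):

```python
import math

# Sketch of the x_values grid construction used by seq_len sweeps:
# powers of two from 2**10 (1024) up to the safe maximum sequence
# length returned by the sweep-config helper.

def seq_len_x_values(max_seq_len: int, min_exp: int = 10) -> list[int]:
    """All powers of two in [2**min_exp, max_seq_len]."""
    return [2**i for i in range(min_exp, int(math.log2(max_seq_len)) + 1)]

print(seq_len_x_values(8192))  # -> [1024, 2048, 4096, 8192]
```

A non-power-of-two maximum is rounded down, so `seq_len_x_values(10000)` yields the same grid as `seq_len_x_values(8192)`.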
diff --git a/benchmark/README.md b/benchmark/README.md index 02c883d92..aefba404b 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -18,26 +18,36 @@ Follow these steps to benchmark and visualize kernel performance: 3. Visualize results - Use the visualization script with optional modes: - * To target specific mode(s), pass `--kernel-operation-mode` one or more values. + * `--sweep-mode`: Select which sweep data to plot. + - `token_length` (default): plots where x-axis is sequence length. + - `model_config`: plots where x-axis is model configuration. + * To target specific operation mode(s), pass `--kernel-operation-mode` one or more values. * If you omit `--kernel-operation-mode`, the script will: - For `speed` metrics: generate plots for all available modes (forward/backward/full). - For `memory` metrics: generate only the `full` plot. Examples: - 1. Specific modes (speed): + 1. Token-length sweep, specific modes (speed): ```bash python benchmarks_visualizer.py \ --kernel-name kto_loss \ --metric-name speed \ --kernel-operation-mode forward backward ``` - 2. All modes (speed): + 2. Token-length sweep, all modes (speed): ```bash python benchmarks_visualizer.py \ --kernel-name kto_loss \ --metric-name speed ``` - 3. Memory (always full): + 3. Model-config sweep (speed): + ```bash + python benchmarks_visualizer.py \ + --kernel-name geglu \ + --metric-name speed \ + --sweep-mode model_config + ``` + 4. Memory (always full): ```bash python benchmarks_visualizer.py \ --kernel-name kto_loss \ @@ -45,4 +55,5 @@ Follow these steps to benchmark and visualize kernel performance: ``` 4. View results - - Generated plots will be saved in `benchmark/visualizations/` \ No newline at end of file + - Generated plots will be saved in `benchmark/visualizations/` + - Filenames include the sweep mode when specified (e.g. 
`geglu_speed_full_model_config.png`) \ No newline at end of file diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py index e33d844ea..613587a9a 100644 --- a/benchmark/benchmarks_visualizer.py +++ b/benchmark/benchmarks_visualizer.py @@ -13,6 +13,12 @@ VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/")) +# Map --sweep-mode values to the x_name used in benchmark CSV data. +# "model_config" sweeps always write x_name="model_config"; token-length +# sweeps use kernel-specific names (e.g. "T"), so we match them by exclusion. +SWEEP_MODE_X_NAME = "model_config" + + @dataclass class VisualizationsConfig: """ @@ -22,6 +28,9 @@ class VisualizationsConfig: kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`) metric_name (str): Metric name to visualize (speed/memory) kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full" + sweep_mode (str, optional): Sweep mode to filter data. "token_length" selects + token/sequence-length sweep data; "model_config" selects model-configuration + sweep data. When None, all data is considered (legacy behaviour). extra_config_filter (str, optional): A string to filter extra_benchmark_config. Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096"). Defaults to None, which means the first available config will be used if multiple exist. @@ -33,7 +42,9 @@ class VisualizationsConfig: kernel_name: str metric_name: str kernel_operation_mode: str = "full" + sweep_mode: str = "token_length" extra_config_filter: str | None = None + gpu_filter: str | None = None display: bool = False overwrite: bool = False @@ -59,6 +70,15 @@ def parse_args() -> VisualizationsConfig: default=None, help="Kernel operation modes to visualize (forward/backward/full). 
If not provided, generate for all available modes.", ) + parser.add_argument( + "--sweep-mode", + type=str, + choices=["token_length", "model_config"], + default="token_length", + help="Sweep mode used when running the benchmark. " + "'token_length' selects token/sequence-length sweep data (default); " + "'model_config' selects model-configuration sweep data.", + ) parser.add_argument( "--extra-config-filter", type=str, @@ -67,6 +87,14 @@ def parse_args() -> VisualizationsConfig: "Can be a substring to match or a JSON-like 'key=value' pair (e.g., \"'H': 4096\" or \"H=4096\" for simple cases). " "Defaults to None (first available config if multiple exist).", ) + parser.add_argument( + "--gpu-filter", + type=str, + default=None, + help="Filter by GPU name. When multiple devices are present, selects " + "the matching GPU (uses most recent match if multiple found). " + "If omitted, the most recent device is used automatically.", + ) parser.add_argument("--display", action="store_true", help="Display the visualization") parser.add_argument( "--overwrite", @@ -78,35 +106,69 @@ def parse_args() -> VisualizationsConfig: return args -def load_data(config: VisualizationsConfig) -> pd.DataFrame: - """Loads the benchmark data from the CSV file and filters it based on the configuration. +def gpu_name_filter(df: pd.DataFrame, gpu_filter: str | None = None) -> pd.DataFrame: + """Filter benchmark data by GPU name when multiple devices are present. Args: - config (VisualizationsConfig): Configuration object for the visualizations script. - - Raises: - ValueError: If no data is found for the given filters. + df: Pre-filtered benchmark dataframe. + gpu_filter: Optional GPU name substring to match. If provided, selects + the matching GPU (uses most recent if multiple match). If None, + automatically picks the most recent device. Returns: - pd.DataFrame: Filtered benchmark dataframe. + pd.DataFrame: Dataframe filtered to a single GPU. 
""" - df = pd.read_csv(DATA_PATH) - df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads) + if "gpu_name" not in df.columns or df.empty: + return df + + unique_gpus = df["gpu_name"].unique() + if len(unique_gpus) <= 1: + return df + + if gpu_filter: + matched = [g for g in unique_gpus if gpu_filter in g] + if matched: + if len(matched) > 1: + # Multiple matches — pick the most recent + matched_df = df[df["gpu_name"].isin(matched)] + selected = matched_df.sort_values("timestamp", ascending=False)["gpu_name"].iloc[0] + print( + f"Warning: Multiple GPUs match filter '{gpu_filter}': {matched}. " + f"Using the most recent: '{selected}'." + ) + else: + selected = matched[0] + else: + # No match — fall back to most recent GPU + selected = df.sort_values("timestamp", ascending=False)["gpu_name"].iloc[0] + print( + f"Warning: No GPU matches filter '{gpu_filter}'. " + f"Available GPUs: {list(unique_gpus)}. " + f"Falling back to most recent device: '{selected}'." + ) + else: + # No filter provided — pick the most recent device + selected = df.sort_values("timestamp", ascending=False)["gpu_name"].iloc[0] + print( + f"Warning: Data contains entries from multiple devices: {list(unique_gpus)}. " + f"Using data from the most recent device: '{selected}'. " + f"Use --gpu-filter to select a specific device." + ) - base_filtered_df = df[ - (df["kernel_name"] == config.kernel_name) - & (df["metric_name"] == config.metric_name) - & (df["kernel_operation_mode"] == config.kernel_operation_mode) - ] + return df[df["gpu_name"] == selected] - if base_filtered_df.empty: - raise ValueError( - f"No data found for kernel_name='{config.kernel_name}', " - f"metric_name='{config.metric_name}', " - f"kernel_operation_mode='{config.kernel_operation_mode}'." 
- ) - unique_extra_configs_str = base_filtered_df["extra_benchmark_config_str"].unique() +def extra_config_filter(df: pd.DataFrame, config: VisualizationsConfig) -> pd.DataFrame: + """Filter benchmark data by extra_benchmark_config. + + Args: + df: Pre-filtered benchmark dataframe (already filtered by kernel, metric, etc.). + config: Visualization configuration with optional extra_config_filter. + + Returns: + pd.DataFrame: Dataframe filtered to a single extra_benchmark_config. + """ + unique_extra_configs_str = df["extra_benchmark_config_str"].unique() selected_extra_config_str = None if len(unique_extra_configs_str) == 0: @@ -114,7 +176,7 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame: "Warning: No extra_benchmark_config found for the initial filters. " "Proceeding with all data from initial filter." ) - return base_filtered_df + return df if config.extra_config_filter: matched_configs = [] @@ -169,14 +231,12 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame: print(f"Using unique extra_benchmark_config: {selected_extra_config_str}") if selected_extra_config_str: - final_filtered_df = base_filtered_df[ - base_filtered_df["extra_benchmark_config_str"] == selected_extra_config_str - ] + result_df = df[df["extra_benchmark_config_str"] == selected_extra_config_str] else: print("Warning: Could not select an extra_benchmark_config. Using data from initial filter if any.") - final_filtered_df = base_filtered_df + result_df = df - if final_filtered_df.empty: + if result_df.empty: raise ValueError( f"No data found after attempting to filter by extra_benchmark_config. 
" f"Selected/Defaulted extra_config_str: {selected_extra_config_str}" @@ -187,7 +247,50 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame: print( f"Plotting data for extra_benchmark_config: {json.loads(selected_extra_config_str if selected_extra_config_str else '{}')}" ) - return final_filtered_df + return result_df + + +def load_data(config: VisualizationsConfig) -> pd.DataFrame: + """Loads the benchmark data from the CSV file and filters it based on the configuration. + + Applies filters in order: kernel/metric/mode → sweep mode → GPU → extra config. + + Args: + config (VisualizationsConfig): Configuration object for the visualizations script. + + Raises: + ValueError: If no data is found for the given filters. + + Returns: + pd.DataFrame: Filtered benchmark dataframe. + """ + df = pd.read_csv(DATA_PATH) + df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads) + + mask = ( + (df["kernel_name"] == config.kernel_name) + & (df["metric_name"] == config.metric_name) + & (df["kernel_operation_mode"] == config.kernel_operation_mode) + ) + + # Filter by sweep mode early, before extra_benchmark_config resolution. + if config.sweep_mode == "model_config": + mask = mask & (df["x_name"] == SWEEP_MODE_X_NAME) + elif config.sweep_mode == "token_length": + mask = mask & (df["x_name"] != SWEEP_MODE_X_NAME) + + base_filtered_df = df[mask] + + if base_filtered_df.empty: + raise ValueError( + f"No data found for kernel_name='{config.kernel_name}', " + f"metric_name='{config.metric_name}', " + f"kernel_operation_mode='{config.kernel_operation_mode}'." 
+ ) + + # Apply GPU filter, then extra config filter + base_filtered_df = gpu_name_filter(base_filtered_df, config.gpu_filter) + return extra_config_filter(base_filtered_df, config) def plot_data(df: pd.DataFrame, config: VisualizationsConfig): @@ -201,6 +304,14 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig): if col in df.columns: df[col] = pd.to_numeric(df[col], errors="coerce") + # Convert x_value to numeric where possible so matplotlib uses a real + # numeric axis (proper proportional spacing). String x_values (e.g. + # model names) stay as-is and will be treated as categorical (evenly spaced). + x_numeric = pd.to_numeric(df["x_value"], errors="coerce") + is_numeric_x = x_numeric.notna().all() + if is_numeric_x: + df["x_value"] = x_numeric + xlabel = df["x_label"].iloc[0] ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})" # Sort by "kernel_provider" to ensure consistent color assignment @@ -229,12 +340,17 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig): errorbar=None, ) + # For numeric x axes, show tick labels only at actual data points + if is_numeric_x: + tick_values = sorted(df["x_value"].unique()) + ax.set_xticks(tick_values) + ax.set_xticklabels([str(int(v)) if v == int(v) else str(v) for v in tick_values]) + # Seaborn can't plot pre-computed error bars, so we need to do it manually lines = ax.get_lines() colors = [line.get_color() for line in lines] for (_, group_data), color in zip(df.groupby("kernel_provider"), colors): - # for i, row in group_data.iterrows(): y_error_lower = group_data["y_value_50"] - group_data["y_value_20"] y_error_upper = group_data["y_value_80"] - group_data["y_value_50"] y_error = [y_error_lower, y_error_upper] @@ -252,9 +368,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig): plt.ylabel(ylabel) plt.tight_layout() + sweep_suffix = f"_{config.sweep_mode}" if config.sweep_mode else "" out_path = os.path.join( VISUALIZATIONS_PATH, - 
f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png", + f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}{sweep_suffix}.png", ) if config.display: @@ -288,6 +405,9 @@ def main(): kernel_name=args.kernel_name, metric_name=args.metric_name, kernel_operation_mode=mode, + sweep_mode=args.sweep_mode, + extra_config_filter=args.extra_config_filter, + gpu_filter=args.gpu_filter, display=args.display, overwrite=args.overwrite, ) diff --git a/benchmark/scripts/benchmark_dyt.py b/benchmark/scripts/benchmark_dyt.py index 2c5129000..549093f25 100644 --- a/benchmark/scripts/benchmark_dyt.py +++ b/benchmark/scripts/benchmark_dyt.py @@ -1,9 +1,12 @@ +import math import os import sys import torch -from benchmark_model_configs import compute_hidden_size_sweep_config +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config +from benchmark_model_configs import compute_seq_len_sweep_config from benchmark_model_configs import estimate_kernel_peak_memory from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput @@ -26,8 +29,9 @@ def _setup_dyt(input: SingleBenchmarkRunInput): from test.transformers.test_dyt import TorchDyT cfg = input.extra_benchmark_config - hidden_size = input.x - x = torch.randn(cfg["BT"], hidden_size, device=device, dtype=cfg["dtype"], requires_grad=True) + hidden_size = cfg["hidden_size"] + bt = input.x + x = torch.randn(bt, hidden_size, device=device, dtype=cfg["dtype"], requires_grad=True) if input.kernel_provider == "liger": layer = LigerDyT(hidden_size=hidden_size, beta=cfg["beta"]).to(device) elif input.kernel_provider == "torch": @@ -49,48 +53,147 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) -BT = 4096 +def _resolve_model_config_dyt(input: SingleBenchmarkRunInput): + """Resolve 
model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_dyt( + SingleBenchmarkRunInput( + x=cfg["BT"], + kernel_provider=input.kernel_provider, + extra_benchmark_config={ + "hidden_size": model_info["hidden_size"], + "dtype": model_info["dtype"], + "beta": cfg["beta"], + }, + ) + ) + + +def bench_speed_dyt_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_dyt(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_dyt_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_dyt(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + if __name__ == "__main__": args = parse_benchmark_script_args() - model = get_benchmark_model_config(args.model) - - for beta in [False, True]: - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=model.hidden_size, - kernel_provider="torch", - extra_benchmark_config={"BT": BT, "dtype": model.dtype, "beta": beta}, + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + for beta in [False, True]: + + def _probe_factory(model_cfg, probe_bt, _beta=beta): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, + "beta": _beta, + }, + ) + x, layer = _setup_dyt(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + common_configs = { + "kernel_name": f"dyt_beta={beta}", + "x_name": "model_config", + "x_label": "model configuration", + 
"x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "torch", "torch_compile"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "BT": sweep.bt, + "beta": beta, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dyt_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_dyt_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_bt = 1024 + + for beta in [False, True]: + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="torch", + extra_benchmark_config={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "beta": beta, + }, + ) + x, layer = _setup_dyt(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_bt + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": f"dyt_beta={beta}", + "x_name": "BT", + "x_label": "B * T", + "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], + "kernel_providers": ["liger", "torch", "torch_compile"], + "extra_benchmark_configs": [ + { + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "beta": beta, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dyt, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_dyt, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, ) - x, layer = 
_setup_dyt(probe_input) - return layer(x) - - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - sweep_config = compute_hidden_size_sweep_config(model, peak_bytes, bt=BT) - x_values = [1024 * i for i in range(1, 17) if 1024 * i <= sweep_config.max_hidden_size] or [model.hidden_size] - - common_configs = { - "kernel_name": f"dyt_beta={beta}", - "x_name": "hidden_size", - "x_label": "hidden_size", - "x_values": x_values, - "kernel_providers": ["liger", "torch", "torch_compile"], - "extra_benchmark_configs": [{"BT": sweep_config.bt, "dtype": model.dtype, "beta": beta}], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_dyt, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_dyt, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) diff --git a/benchmark/scripts/benchmark_geglu.py b/benchmark/scripts/benchmark_geglu.py index d59564baf..52a03403b 100644 --- a/benchmark/scripts/benchmark_geglu.py +++ b/benchmark/scripts/benchmark_geglu.py @@ -2,6 +2,8 @@ import torch +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config from benchmark_model_configs import compute_seq_len_sweep_config from benchmark_model_configs import estimate_kernel_peak_memory from benchmark_model_configs import get_benchmark_model_config @@ -55,61 +57,154 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) -if __name__ == "__main__": - args = parse_benchmark_script_args() - - model = get_benchmark_model_config(args.model) - probe_seq_len = 1024 - - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", +def _resolve_model_config_geglu(input: 
SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_geglu( + SingleBenchmarkRunInput( + x=cfg["seq_len"], + kernel_provider=input.kernel_provider, extra_benchmark_config={ - "bsz": 1, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "gelu_pytorch_tanh", - "dtype": model.dtype, + "bsz": cfg["bsz"], + "hidden_size": model_info["hidden_size"], + "intermediate_size": model_info["intermediate_size"], + "hidden_act": cfg["hidden_act"], + "dtype": model_info["dtype"], }, ) - x, layer = _setup_geglu(probe_input) - return layer(x) - - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_seq_len - - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) - - common_configs = { - "kernel_name": "geglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "bsz": config.batch_size, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "gelu_pytorch_tanh", - "dtype": model.dtype, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_geglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_geglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, ) + + +def bench_speed_geglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_geglu(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def 
bench_memory_geglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_geglu(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_seq_len): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model_cfg.hidden_size, + "intermediate_size": model_cfg.intermediate_size, + "hidden_act": "gelu_pytorch_tanh", + "dtype": model_cfg.dtype, + }, + ) + x, layer = _setup_geglu(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "geglu", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "bsz": sweep.batch_size, + "seq_len": sweep.seq_len, + "hidden_act": "gelu_pytorch_tanh", + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_geglu_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_geglu_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + 
probe_seq_len = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "gelu_pytorch_tanh", + "dtype": model.dtype, + }, + ) + x, layer = _setup_geglu(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "geglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "bsz": config.batch_size, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "gelu_pytorch_tanh", + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_geglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_geglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py index 630b0d555..c4b101366 100644 --- a/benchmark/scripts/benchmark_model_configs.py +++ b/benchmark/scripts/benchmark_model_configs.py @@ -28,7 +28,9 @@ from dataclasses import dataclass from typing import Callable from typing import Dict +from typing import List from typing import Optional +from typing import Tuple import torch @@ -73,16 +75,20 @@ class SeqLenSweepConfig: @dataclass(frozen=True) -class HiddenSizeSweepConfig: - """Config for benchmarks that sweep hidden_size with 
fixed BT (e.g. DyT). +class ModelConfigSweepConfig: + """Config for benchmarks that sweep across model configs. Attributes: - bt: Fixed batch * seq dimension. - max_hidden_size: Upper bound for hidden_size sweep. + model_configs: Model configs to benchmark (as tuple for immutability). + bt: Effective total tokens (batch_size * seq_len). + batch_size: Safe batch size across all model configs. + seq_len: Safe sequence length across all model configs. """ + model_configs: Tuple[ModelConfig, ...] bt: int - max_hidden_size: int + batch_size: int + seq_len: int # ── Model Profiles ────────────────────────────────────────────────────────── @@ -224,35 +230,48 @@ def compute_seq_len_sweep_config( return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len) -def compute_hidden_size_sweep_config( - model_cfg: ModelConfig, - kernel_peak_bytes: int, - bt: int = 4096, +def compute_model_config_sweep_config( + model_configs: List[ModelConfig], + probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]], + bt: int = 2048, memory_utilization: float = 0.4, - max_hidden_size_multiplier: int = 4, -) -> HiddenSizeSweepConfig: - """Compute safe max_hidden_size for hidden_size sweep (e.g. DyT). +) -> ModelConfigSweepConfig: + """Find safe (batch_size, seq_len) that works across all model configs. - For kernels with shape (BT, hidden_size) where BT is fixed and we sweep - hidden_size. Uses probe peak memory to derive max_hidden_size. - Device memory is obtained internally via :func:`~liger_kernel.utils.get_total_gpu_memory`. + Probes each model config at a small token count to measure peak memory, + then picks the most conservative parameters that fit within device memory. Args: - model_cfg: Model config. - kernel_peak_bytes: Peak memory from probe (BT, model.hidden_size). - bt: Fixed BT dimension; must match the probe. + model_configs: Model configs to benchmark. + probe_fn_factory: Factory ``(model_cfg, probe_seq_len) -> probe_fn``. 
+ The returned probe_fn should perform setup + forward pass and + return a tensor suitable for ``.backward()``, same contract as + :func:`estimate_kernel_peak_memory`'s *probe_fn*. + bt: Target total tokens (batch_size * seq_len). memory_utilization: Fraction of device memory to use. - max_hidden_size_multiplier: Cap max_hidden_size at model.hidden_size * this. """ total_memory_gb = get_total_gpu_memory() usable_bytes = total_memory_gb * (1024**3) * memory_utilization - kernel_bpt = max(1, kernel_peak_bytes // bt) - max_hidden_size = min( - model_cfg.hidden_size * max_hidden_size_multiplier, - max( - model_cfg.hidden_size, - int(usable_bytes * model_cfg.hidden_size / (bt * kernel_bpt)), - ), + + probe_seq_len = min(bt, 1024) + max_bytes_per_token = 0 + + for model_cfg in model_configs: + probe_fn = probe_fn_factory(model_cfg, probe_seq_len) + peak_bytes = estimate_kernel_peak_memory(probe_fn) + bpt = max(1, peak_bytes // probe_seq_len) + max_bytes_per_token = max(max_bytes_per_token, bpt) + + max_tokens = max(1, int(usable_bytes / max_bytes_per_token)) + safe_bt = min(bt, max_tokens) + + seq_len = min(safe_bt, 8192) + seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024 + batch_size = max(1, safe_bt // seq_len) + + return ModelConfigSweepConfig( + model_configs=tuple(model_configs), + bt=batch_size * seq_len, + batch_size=batch_size, + seq_len=seq_len, ) - max_hidden_size = max(1024, 2 ** int(math.log2(max_hidden_size))) - return HiddenSizeSweepConfig(bt=bt, max_hidden_size=max_hidden_size) diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index 8d46572fd..dc34fd60d 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -2,6 +2,8 @@ import torch +from benchmark_model_configs import MODEL_REGISTRY +from benchmark_model_configs import compute_model_config_sweep_config from benchmark_model_configs import compute_seq_len_sweep_config from benchmark_model_configs import 
estimate_kernel_peak_memory from benchmark_model_configs import get_benchmark_model_config @@ -55,61 +57,154 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) -if __name__ == "__main__": - args = parse_benchmark_script_args() - - model = get_benchmark_model_config(args.model) - probe_seq_len = 1024 - - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", +def _resolve_model_config_swiglu(input: SingleBenchmarkRunInput): + """Resolve model-config-sweep input into standard setup args.""" + cfg = input.extra_benchmark_config + model_info = cfg["model_configs"][input.x] + return _setup_swiglu( + SingleBenchmarkRunInput( + x=cfg["seq_len"], + kernel_provider=input.kernel_provider, extra_benchmark_config={ - "bsz": 1, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, + "bsz": cfg["bsz"], + "hidden_size": model_info["hidden_size"], + "intermediate_size": model_info["intermediate_size"], + "hidden_act": cfg["hidden_act"], + "dtype": model_info["dtype"], }, ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_seq_len - - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) - - common_configs = { - "kernel_name": "swiglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "bsz": config.batch_size, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_swiglu, - 
kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_swiglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, ) + + +def bench_speed_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_swiglu(input) + return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) + + +def bench_memory_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + x, layer = _resolve_model_config_swiglu(input) + return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + if args.sweep_mode == "model_config": + all_model_configs = list(MODEL_REGISTRY.values()) + + def _probe_factory(model_cfg, probe_seq_len): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model_cfg.hidden_size, + "intermediate_size": model_cfg.intermediate_size, + "hidden_act": "silu", + "dtype": model_cfg.dtype, + }, + ) + x, layer = _setup_swiglu(probe_input) + return layer(x) + + return _probe + + sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) + + model_configs_info = { + cfg.name: { + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "dtype": cfg.dtype, + } + for cfg in sweep.model_configs + } + + common_configs = { + "kernel_name": "swiglu", + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "model_configs": model_configs_info, + "bsz": sweep.batch_size, + "seq_len": sweep.seq_len, + 
"hidden_act": "silu", + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_swiglu_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_swiglu_model_config, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) + else: + model = get_benchmark_model_config(args.model) + probe_seq_len = 1024 + + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "silu", + "dtype": model.dtype, + }, + ) + x, layer = _setup_swiglu(probe_input) + return layer(x) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) + kernel_bpt = peak_bytes // probe_seq_len + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + common_configs = { + "kernel_name": "swiglu", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "bsz": config.batch_size, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "silu", + "dtype": model.dtype, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index e6b4fc9e8..0cb307d19 100644 
--- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -29,7 +29,7 @@ @dataclass class SingleBenchmarkRunInput: - x: Union[int, float] + x: Union[int, float, str] kernel_provider: str kernel_operation_mode: Optional[str] = "" extra_benchmark_config: Optional[Dict[str, Any]] = None @@ -59,7 +59,7 @@ class BenchmarkData: gpu_name: str x_name: str x_label: str - x_values: List[float] + x_values: List[Union[float, str]] y_values_50: List[float] y_values_20: List[float] y_values_80: List[float] @@ -79,7 +79,7 @@ class BenchmarkDataCSVRow: metric_unit: str x_name: str x_label: str - x_value: float + x_value: Union[float, str] y_value_50: float y_value_20: float y_value_80: float @@ -341,7 +341,7 @@ def run_benchmarks( metric_unit: str, x_name: str, x_label: str, - x_values: List[Union[float, int]], + x_values: List[Union[float, int, str]], kernel_providers: List[str], kernel_operation_modes: Optional[List[str]] = [None], extra_benchmark_configs: Optional[List[Dict[str, Any]]] = None, @@ -425,6 +425,17 @@ def parse_benchmark_script_args(): action="store_true", help="Flag to overwrite existing benchmark data with current run.", ) + parser.add_argument( + "--sweep-mode", + type=str, + default="token_length", + choices=["token_length", "model_config"], + help=( + "Benchmark sweep dimension. " + "'token_length': sweep sequence length with fixed model. " + "'model_config': sweep model configs with fixed token length." + ), + ) parser.add_argument( "--model", type=str, @@ -435,5 +446,14 @@ def parse_benchmark_script_args(): "Defaults to llama_3_8b when not specified." ), ) + parser.add_argument( + "--bt", + type=int, + default=2048, + help=( + "Target total tokens (batch_size * seq_len) for model-config " + "sweep. Only used when --sweep-mode=model_config." + ), + ) args = parser.parse_args() return args
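For reference, the `(batch_size, seq_len)` derivation that `compute_model_config_sweep_config` introduces above can be isolated as a standalone sketch. The helper name `split_bt` and the sample numbers are illustrative; the clamping constants (8192-token cap, 1024-token floor, power-of-two flooring) mirror the diff:

```python
import math


def split_bt(target_bt: int, max_tokens: int, max_seq_len: int = 8192):
    """Mirror the (batch_size, seq_len) split in compute_model_config_sweep_config:
    clamp the token budget to what the memory probe allows, floor seq_len to a
    power of two (1024 minimum), and spend any remaining budget on batch size."""
    safe_bt = min(target_bt, max_tokens)   # respect the probed memory budget
    seq_len = min(safe_bt, max_seq_len)    # cap single-sequence length
    seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
    batch_size = max(1, safe_bt // seq_len)
    return batch_size, seq_len


print(split_bt(2048, 100_000))    # ample memory: one 2048-token sequence -> (1, 2048)
print(split_bt(2048, 1_500))      # tight budget: seq_len floors to 1024  -> (1, 1024)
print(split_bt(20_000, 100_000))  # large budget: seq_len caps at 8192    -> (2, 8192)
```

Note the 1024-token floor means the split can slightly overshoot a very small probed budget; in practice the `memory_utilization` headroom (0.4 of device memory by default) absorbs this.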