[Benchmark]: Add --sweep-mode and --bt to benchmark CLI. #1163
noemotiovon wants to merge 3 commits into linkedin:main
Conversation
## Benchmark Framework Design

This document describes the overall design of the Liger-Kernel benchmark suite, including its two benchmark dimensions, the shared infrastructure, and the phased implementation plan.

### 1. Benchmark Dimensions

Every operator should ideally be benchmarked along two orthogonal dimensions:
**D1: Non-model dimension sweep (implemented).** Sweep non-model dimensions (e.g. sequence length, BT) with a fixed model config selected via `--model`.

**D2: Model dimension sweep (implemented).** Sweep model architecture dimensions (e.g. hidden_size, or discrete model configs from `MODEL_REGISTRY`).

### 2. D2 Design Choices

Following the maintainer discussion, we evaluated three approaches:
**Decision:** C as the primary approach, with A as optional enrichment for ops where single-parameter scaling is important.

**Rationale:**
### 3. Universal Token Length for D2

For D2 benchmarks, we need a fixed token length that is safe (no OOM) across all model configs and all operators.

**Strategy**
**Proposed CLI**

```shell
# D1 (existing): token-length sweep with fixed model
python benchmark_geglu.py --model llama_3_8b

# D2 (new): model-config sweep with fixed token length
python benchmark_geglu.py --sweep-mode model_config --bt 2048
```

### 4. Infrastructure Changes

#### 4.1 New config type

```python
@dataclass(frozen=True)
class ModelConfigSweepConfig:
    """Config for D2 benchmarks that sweep across model configs."""

    model_configs: List[ModelConfig]  # models to benchmark
    bt: int                           # fixed batch * seq_len
    batch_size: int                   # safe batch size
    seq_len: int                      # safe seq_len
```

#### 4.2 New helper

```python
def compute_model_config_sweep_config(
    model_configs: List[ModelConfig],
    probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]],
    bt: int = 2048,
    memory_utilization: float = 0.4,
) -> ModelConfigSweepConfig:
"""Find safe (batch_size, seq_len) that works across all model configs.
For each model config, runs probe_fn_factory(model_config, bt) to measure
peak memory, then picks the most conservative batch_size / seq_len.
"""
...4.3 Script-level changesEach benchmark script gains a model-config sweep code path gated by if args.sweep_mode == "model_config":
    configs = [MODEL_REGISTRY[name] for name in MODEL_REGISTRY]
    sweep = compute_model_config_sweep_config(configs, probe_fn_factory=..., bt=args.bt)
    # x_values = model config indices
    # extra_benchmark_configs = contains all model configs
    ...
else:
    # existing token-length sweep logic
    ...
```

#### 4.4 Visualization

D2 results produce grouped bar charts (speedup or throughput) rather than line charts.
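Putting 4.2 and 4.3 together, the safe-split search can be sketched as below. Note this is only an illustration of the idea: the probe signature (`probe_fn(cfg, batch_size, seq_len)` returning an estimated peak in GiB), the power-of-two halving, and the fixed budget are assumptions here, not the PR's actual implementation.

```python
from dataclasses import dataclass
from typing import Callable, List

# Stand-ins for the PR's types; field names mirror the snippets above,
# but the search logic and probe signature are illustrative only.
@dataclass(frozen=True)
class ModelConfig:
    name: str
    hidden_size: int

@dataclass(frozen=True)
class ModelConfigSweepConfig:
    model_configs: List[ModelConfig]
    bt: int          # fixed batch * seq_len
    batch_size: int  # safe batch size
    seq_len: int     # safe seq_len

def compute_model_config_sweep_config(
    model_configs: List[ModelConfig],
    probe_fn: Callable[[ModelConfig, int, int], float],  # estimated peak GiB
    bt: int = 2048,
    memory_budget_gib: float = 32.0,
) -> ModelConfigSweepConfig:
    """Try (batch_size, seq_len) splits of bt from the longest seq_len down,
    keeping the first split whose worst-case probe fits the memory budget."""
    seq_len = bt
    while seq_len >= 1:
        batch_size = bt // seq_len
        # Most conservative choice: the split must fit the *largest* model.
        peak = max(probe_fn(cfg, batch_size, seq_len) for cfg in model_configs)
        if peak <= memory_budget_gib:
            return ModelConfigSweepConfig(model_configs, bt, batch_size, seq_len)
        seq_len //= 2  # halve seq_len, doubling batch_size
    raise RuntimeError("no (batch_size, seq_len) split fits the memory budget")
```

Because the split is chosen against the worst-case model, one `(batch_size, seq_len)` pair is then valid for every config in the sweep.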
### 5. Phased Implementation Plan

**Phase 1: Foundation (current PR).** Status: complete

**Phase 2: Model-config sweep (D2).** Status: complete

**Phase 3: Rollout and visualization.** Status: planned
### 6. Directory Structure
**GEGLU Test.** Script:

**SwiGLU Test.** Script:

**DyT Test.** Script:
Hi @Tcc0403, can you take a look at my code?
- benchmark_model_configs: replace the hidden-size sweep with `compute_model_config_sweep_config` / `ModelConfigSweepConfig`; probe each registry model to pick a safe batch_size and seq_len for discrete sweeps.
- benchmark_geglu / benchmark_swiglu: support the model_config sweep across `MODEL_REGISTRY` via `_resolve_model_config_*` helpers.
- benchmark_dyt: the default path sweeps B*T with fixed model dimensions (`compute_seq_len_sweep_config`); optional model_config sweep; setup uses cfg hidden_size and input.x as BT.
- utils: allow string x / x_values for model-name indices; extend types.
- benchmarks_visualizer: forward `extra_config_filter` to plotting.
- BENCHMARK_GUIDELINES: document D1/D2 sweep patterns and the model_config flow.
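For the D1 path described above (sweeping B*T with fixed model dimensions), one way to generate OOM-safe sweep points is a power-of-two seq_len schedule under a B*T cap. The helper below is a hypothetical sketch of that idea, not the PR's `compute_seq_len_sweep_config`; all parameter defaults are assumptions.

```python
from typing import Iterator

def seq_len_sweep(batch_size: int = 4, min_seq_len: int = 128, max_bt: int = 2 ** 15) -> Iterator[int]:
    """Yield power-of-two seq_len values while batch_size * seq_len stays
    within the max_bt cap, so every sweep point is safe by construction."""
    seq_len = min_seq_len
    while batch_size * seq_len <= max_bt:
        yield seq_len
        seq_len *= 2
```

With the defaults above, the sweep stops once `batch_size * seq_len` would exceed the cap, giving a geometric x-axis for the D1 line charts.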
…aling

- Add `--sweep-mode` argument (token_length|model_config) to benchmarks_visualizer.py for filtering data by sweep type via the x_name column in the CSV, defaulting to token_length
- Fix x-axis scaling: convert numeric x_values to a proper numeric type so matplotlib plots them proportionally instead of equally spaced; string x_values (e.g. model names) remain categorical
- Set tick labels only at actual data points for numeric axes
- Include the sweep_mode suffix in output PNG filenames to avoid overwriting when both sweep types exist for the same kernel
- Update README.md with `--sweep-mode` usage and examples
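The numeric-vs-categorical distinction for x_values can be sketched as a small coercion helper. This is an illustration of the rule described above, not code from benchmarks_visualizer.py:

```python
from typing import List, Sequence, Tuple, Union

def coerce_x_values(x_values: Sequence) -> Tuple[List[Union[float, str]], bool]:
    """Convert x_values to floats when every entry parses as a number, so a
    plot can space them proportionally; otherwise keep them as categorical
    strings (e.g. model names). Returns (values, is_numeric)."""
    try:
        return [float(v) for v in x_values], True
    except (TypeError, ValueError):
        return [str(v) for v in x_values], False
```

A plotting layer would then place numeric values at their true positions and string values at evenly spaced categorical ticks.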
benchmarks_visualizer.py:
- Add `--gpu-filter` CLI flag to select a specific GPU when benchmark
data contains results from multiple devices; falls back to the most
recent device with a warning when omitted or unmatched.
- Extract `gpu_name_filter()` and `extra_config_filter()` as standalone
helpers; `load_data()` now applies filters in explicit order:
kernel/metric/mode → sweep-mode → GPU → extra config.
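The explicit filter order and the GPU fallback can be illustrated over plain dict rows. The column names (`kernel_name`, `gpu_name`) and the "most recent device" rule are assumptions for this sketch; only the ordering and the warning-on-fallback behavior come from the description above.

```python
import warnings
from typing import Dict, List, Optional

def filter_rows(
    rows: List[Dict[str, str]],
    kernel: Optional[str] = None,
    sweep_mode: Optional[str] = None,
    gpu: Optional[str] = None,
) -> List[Dict[str, str]]:
    """Apply filters in explicit order: kernel -> sweep-mode -> GPU.
    When the GPU filter is omitted or unmatched, fall back to the most
    recent device present in the remaining data and emit a warning."""
    if kernel is not None:
        rows = [r for r in rows if r["kernel_name"] == kernel]
    if sweep_mode is not None:
        rows = [r for r in rows if r["x_name"] == sweep_mode]
    gpus = [r["gpu_name"] for r in rows]
    if gpu in gpus:
        rows = [r for r in rows if r["gpu_name"] == gpu]
    elif gpus:
        fallback = gpus[-1]  # most recent device in file order (assumed)
        warnings.warn(f"GPU filter {gpu!r} not matched; using {fallback!r}")
        rows = [r for r in rows if r["gpu_name"] == fallback]
    return rows
```

Applying the narrow filters first keeps the GPU fallback decision scoped to the data the user actually asked for.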
BENCHMARK_GUIDELINES.md:
- Add guideline: import baseline kernels from the test suite instead
of duplicating reference implementations in benchmark scripts.
- Remove the continuous hidden-size sweep variant (D2.1) and
`compute_hidden_size_sweep_config()` reference; D2 now covers only
the discrete model-config sweep.
Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Hi @Tcc0403, I’ve made updates according to the review comments. Happy to discuss further!






This PR follows PR #1162 and implements Phase 2.
Hardware Type: Atlas 800I A2
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence