diff --git a/benchmarks/nightly/autogen.yaml b/benchmarks/nightly/autogen.yaml
index 845411b4..972af20d 100644
--- a/benchmarks/nightly/autogen.yaml
+++ b/benchmarks/nightly/autogen.yaml
@@ -140,3 +140,6 @@ rope_bwd:
 swiglu_bwd:
   op: swiglu
   args: --op swiglu --baseline torch_swiglu --metrics speedup --bwd --only liger_swiglu,torch_swiglu
+launch_latency:
+  op: launch_latency
+  args: --op launch_latency --metrics walltime
diff --git a/benchmarks/nightly/gen.py b/benchmarks/nightly/gen.py
index 55e23413..030169ba 100644
--- a/benchmarks/nightly/gen.py
+++ b/benchmarks/nightly/gen.py
@@ -76,6 +76,8 @@ def process_manual_options(
             run_configs[benchmark]["disabled"] = True
     for benchmark in extra_args:
         run_configs[benchmark]["args"] = extra_args[benchmark]["args"]
+    for benchmark, benchmark_config in options.get("enabled", {}).items():
+        run_configs[benchmark] = benchmark_config.copy()
     return run_configs
 
 
diff --git a/benchmarks/nightly/manual.yaml b/benchmarks/nightly/manual.yaml
index 469dac82..91a1b1fd 100644
--- a/benchmarks/nightly/manual.yaml
+++ b/benchmarks/nightly/manual.yaml
@@ -7,6 +7,10 @@ disabled:
   - fp8_gemm_fwd
   - fp8_gemm_rowwise_fwd
   - fp8_gemm_rowwise_grouped_fwd
+enabled:
+  launch_latency:
+    op: launch_latency
+    args: --op launch_latency --metrics walltime
 extra_args:
   # triton_tutorial_flash_v2_opt does not work on Triton main branch
   bf16_flash_attention_fwd:
diff --git a/tritonbench/operators/launch_latency/operator.py b/tritonbench/operators/launch_latency/operator.py
index e2168b10..f5f219ed 100644
--- a/tritonbench/operators/launch_latency/operator.py
+++ b/tritonbench/operators/launch_latency/operator.py
@@ -1,6 +1,8 @@
 import triton.language as tl
 from torch import zeros
 from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+from torch._inductor.utils import triton_version_uses_attrs_dict
 from triton.compiler import CompiledKernel
 
 from tritonbench.utils.triton_op import (
@@ -39,7 +41,10 @@ def nop_triton_compiled_kernel_run(self, *args):
     else:
         bin = nop_with_args_kernel[1,](*args)
 
-    args = args[:-5]  # remove tl.constexpr args
+    # triton <= 3.3 does not include tl.constexpr args in call
+    # but triton 3.4 does
+    if not triton_version_uses_attrs_dict():
+        args = args[:-5]
     function = bin.function
     metadata = (
         bin.packed_metadata if hasattr(bin, "packed_metadata") else bin.metadata
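
Note (not part of the patch): a minimal sketch of how the new "enabled" section in
manual.yaml is consumed by the loop added to process_manual_options() in gen.py.
The standalone merge_enabled() helper and the inline YAML snippet are hypothetical
illustrations, assuming PyYAML is available; only the options.get("enabled", {})
loop mirrors the actual change above.

    import yaml

    # Hypothetical inline copy of the "enabled" block added to manual.yaml.
    manual_yaml = """
    enabled:
      launch_latency:
        op: launch_latency
        args: --op launch_latency --metrics walltime
    """

    def merge_enabled(run_configs, options):
        # Mirrors the loop added to process_manual_options(): every benchmark
        # listed under "enabled" is copied into the run configs verbatim.
        for benchmark, benchmark_config in options.get("enabled", {}).items():
            run_configs[benchmark] = benchmark_config.copy()
        return run_configs

    options = yaml.safe_load(manual_yaml)
    configs = merge_enabled({}, options)
    print(configs["launch_latency"]["args"])  # --op launch_latency --metrics walltime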