diff --git a/benchmarks/third_party/sglang/scaled_mm_benchmark.py b/benchmarks/third_party/sglang/scaled_mm_benchmark.py
index d79850a3b9..772e468415 100644
--- a/benchmarks/third_party/sglang/scaled_mm_benchmark.py
+++ b/benchmarks/third_party/sglang/scaled_mm_benchmark.py
@@ -23,20 +23,20 @@ def is_weak_contiguous(x: torch.Tensor):
 
 def get_matmul_batched_autotune_configs() -> List[triton.Config]:
     configs = [
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2, 3]
     ] + [
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w)
         for s in [2]
-        for (m, w) in ([('large', 32), ('small', 64)])
+        for (m, w) in ([('256', 32), ('128', 64)])
     ] + [
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2]
     ] + [
-        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2]
     ] + [
-        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4)
+        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=4)
         for s in [2]
     ]
     return configs
diff --git a/benchmarks/third_party/vllm/batched_moe_benchmark.py b/benchmarks/third_party/vllm/batched_moe_benchmark.py
index 18685edf95..e76e411477 100644
--- a/benchmarks/third_party/vllm/batched_moe_benchmark.py
+++ b/benchmarks/third_party/vllm/batched_moe_benchmark.py
@@ -201,20 +201,20 @@ def expert_triton_kernel(
 
 def get_matmul_batched_autotune_configs() -> List[triton.Config]:
     configs = [
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2, 3]
     ] + [
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': m}, num_stages=s, num_warps=w)
         for s in [2]
-        for (m, w) in ([('large', 32), ('small', 64)])
+        for (m, w) in ([('256', 32), ('128', 64)])
     ] + [
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2]
     ] + [
-        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=32)
+        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 512, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=32)
         for s in [2]
     ] + [
-        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': 'large'}, num_stages=s, num_warps=4)
+        triton.Config({'BLOCK_M': 8, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_stages=s, num_warps=4)
         for s in [2]
     ]
     return configs
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
index 9f87db696b..6ad0c8734b 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -171,7 +171,7 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out,  #
 
 
 configs = [
-    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': '256'}, num_stages=s, num_warps=w) \
     for BM in [128, 256] \
     for BN in [32, 64] \
     for s in [2, 3, 4] \
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index fddd5ebcfb..da4989a576 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -21,18 +21,18 @@
 def get_matmul_autotune_configs() -> List[triton.Config]:
     configs = [
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [1, 2, 3]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
-                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
+                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('256', 32), ('128', 64)])
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
-                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
+                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('256', 32), ('128', 64)])
     ]
     return configs
 
@@ -88,26 +88,26 @@ def matmul_kernel_with_block_pointers(
 def get_matmul_batched_autotune_configs() -> List[triton.Config]:
     configs = [
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2, 3]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
-                      num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('large', 32), ('small', 64)])
+                      num_stages=s, num_warps=w) for s in [2] for (m, w) in ([('256', 32), ('128', 64)])
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2, 3]
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=s, num_warps=4) for s in [2]
     ]
     return configs
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 2cbd928ea3..d45075260f 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -37,19 +37,19 @@ def suffix():
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
@@ -105,22 +105,22 @@ def matmul_kernel_with_tensor_descriptors(
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 560c552b3f..e5c7f6e9e0 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -35,19 +35,19 @@ def gelu(x):
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
@@ -97,22 +97,22 @@ def matmul_kernel_with_tensor_descriptors(
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 612691fdb7..2b030879cb 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -17,19 +17,19 @@
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
@@ -82,22 +82,22 @@ def matmul_kernel_with_tensor_descriptors(
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=3, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
         triton.Config(
-            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': '256'},
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index e6ca4c8bf6..82d7c8c78f 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -14,7 +14,7 @@
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 4, 'SPLIT_K': 4, 'grf_mode': 'large'},
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 4, 'SPLIT_K': 4, 'grf_mode': '256'},
                       num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 27a519a28f..414a4077fb 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -96,7 +96,7 @@ def mac_loop(
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
@@ -132,7 +132,7 @@ def first_wave(
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py
index bcb0077f6e..56a6767555 100644
--- a/python/test/unit/language/test_matmul.py
+++ b/python/test/unit/language/test_matmul.py
@@ -1299,7 +1299,7 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
     if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
         kernel_kwargs["num_warps"] = 8
     if is_xpu():
-        kernel_kwargs["grf_mode"] = "large"
+        kernel_kwargs["grf_mode"] = "256"
     out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
                                    b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
                                    dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N,
diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py
index ff9ebc5b4c..dc6e27e56e 100644
--- a/python/triton/tools/compile.py
+++ b/python/triton/tools/compile.py
@@ -77,7 +77,7 @@ def main():
     parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
     parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
     parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
-    parser.add_argument("--grf-mode", "-gm", type=str, default="large", help="Detemine spv build flags")
+    parser.add_argument("--grf-mode", "-gm", type=str, default="256", help="Detemine spv build flags")
     parser.add_argument("--generate-native-code", "-gnc", action="store_true",
                        help="Generate native binary instead of SPV for XPU")
     cli_args = parser.parse_args()
diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index f7238f14e2..e5d05b4db2 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -168,7 +168,7 @@ def is_xpu():
 def get_xpu_autotune_config():
     return [
         triton.Config(
-            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=4, num_warps=32),
         triton.Config(
             {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': 'auto'},
diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py
index ff93984710..be2dc05965 100644
--- a/python/tutorials/05-layer-norm.py
+++ b/python/tutorials/05-layer-norm.py
@@ -258,7 +258,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps):
         _layer_norm_fwd_fused[(M, )](  #
             x_arg, y, weight, bias, mean, rstd,  #
             x_arg.stride(0), N, eps,  #
-            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1, grf_mode='large')
+            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1, grf_mode='256')
         ctx.save_for_backward(x, weight, bias, mean, rstd)
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
@@ -290,13 +290,13 @@ def backward(ctx, dy):
             x_arg.stride(0), N,  #
             BLOCK_SIZE_N=ctx.BLOCK_SIZE,  #
             GROUP_SIZE_M=GROUP_SIZE_M,  #
-            num_warps=ctx.num_warps, grf_mode='large')
+            num_warps=ctx.num_warps, grf_mode='256')
         grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )
         # accumulate partial sums in separate kernel
         _layer_norm_bwd_dwdb[grid](
             _dw, _db, dw, db, min(GROUP_SIZE_M, M), N,  #
             BLOCK_SIZE_M=32,  #
-            BLOCK_SIZE_N=128, num_warps=ctx.num_warps, num_ctas=1, grf_mode='large')
+            BLOCK_SIZE_N=128, num_warps=ctx.num_warps, num_ctas=1, grf_mode='256')
         return dx, None, dw, db, None
 
 
diff --git a/python/tutorials/10-experimental-block-pointer.py b/python/tutorials/10-experimental-block-pointer.py
index 2f8532b28e..2bc456480b 100644
--- a/python/tutorials/10-experimental-block-pointer.py
+++ b/python/tutorials/10-experimental-block-pointer.py
@@ -99,18 +99,18 @@
 @triton.autotune(
     configs=[
         triton.Config(
-            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [1, 2, 3]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
-                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('large', 32), ('small', 64)])
+                      num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in ([('256', 32), ('128', 64)])
     ] + [
         triton.Config(
-            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
+            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'},
             num_stages=s, num_warps=32) for s in [2]
     ] + [
         triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
-                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('large', 32), ('small', 64)])
+                      num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in ([('256', 32), ('128', 64)])
     ],
     key=['M', 'N', 'K'],
 )
diff --git a/scripts/flash_attention.py b/scripts/flash_attention.py
index 2f7415b484..1966aa324a 100755
--- a/scripts/flash_attention.py
+++ b/scripts/flash_attention.py
@@ -42,7 +42,7 @@ def get_configs(options):
     warps_values = options.warps if options.warps else [8, 16, 32]
     split_barriers_scope = options.split_barriers_scope if options.split_barriers_scope else 'None'
     return [
-        triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'split_barriers_scope': split_barriers_scope},
+        triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': '256', 'split_barriers_scope': split_barriers_scope},
                       num_stages=s, num_warps=w)
         for BM in bm_values
        for BN in bn_values
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index f6525864b2..d5b4baef01 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -33,7 +33,7 @@ class XPUOptions:
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
     allow_fp8e4nv: bool = False
     allow_fp8e4b15: bool = True
-    grf_mode: tuple = ('small', 'large', 'auto', 'default')
+    grf_mode: str = 'default'
     split_barriers_scope: str = 'None'
     max_num_imprecise_acc_default: int = 0  # `max_num_imprecise_acc` only applies to fp8 -> fp32 dot on sm_90 for cuda
     extern_libs: dict = None
@@ -369,14 +369,16 @@ def make_spv(src, metadata, options, device_arch):
     spirv, name = intel.translate_to_spirv(src)
     metadata["name"] = name
     metadata.setdefault("build_flags", "")
-    if options.grf_mode == 'small':
+    if options.grf_mode == '128':
         metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
-    elif options.grf_mode == 'large':
+    elif options.grf_mode == '256':
         if options.num_warps > 32:
-            raise RuntimeError("grf_mode = large cannot be used with num_warps > 32")
+            raise RuntimeError("grf_mode = 256 cannot be used with num_warps > 32")
         metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
     elif options.grf_mode == 'auto':
         metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"
+    elif options.grf_mode != 'default':
+        raise RuntimeError(f"Unknown grf_mode: {options.grf_mode}")
     if knobs.intel.disable_igc_opt:
         metadata["build_flags"] += " -cl-opt-disable"