Skip to content

Commit ff57a4d

Browse files
authored
[Bench][AMD] Tune MoE compilation config for GFX950 (#7127)
1 parent 6dd7d6a commit ff57a4d

File tree

4 files changed: +12 additions, −5 deletions

python/triton_kernels/bench/bench_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
198198
x_bw = [x_bw[0], x_comp[0]]
199199
y_bw = [opints[0] * max_tbps, max_tflops]
200200
y_comp = [max_tflops] * len(x_comp)
201-
ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)")
201+
ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.1f} TB/s)")
202202
ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
203203
# plot data
204204
ax.scatter(xs, perf, marker="+")

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22
import triton
33
from triton_kernels.numerics_details.mxfp import SwizzlingType
4+
from triton_kernels.target_info import get_cdna_version
45
import torch
56

67
from . import opt_flags_amd, opt_flags_nvidia
@@ -55,15 +56,20 @@ def make_default_opt_flags_amd(
5556
tokens_per_expt = max(1, m // routing_data.n_expts_tot)
5657
else:
5758
tokens_per_expt = routing_data.expected_tokens_per_expt
59+
60+
is_cdna4 = get_cdna_version() == 4
5861
# block_m
5962
if constraints.get("block_m", None):
6063
block_m = constraints["block_m"]
6164
elif enforce_bitwise_invariance:
62-
block_m = 128
65+
block_m = 256 if is_cdna4 else 128
6366
elif tokens_per_expt >= 512 and n >= 2048:
67+
block_m = 256 if is_cdna4 else 128
68+
elif is_cdna4 and m >= 512:
6469
block_m = 128
6570
else:
6671
block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
72+
6773
if routing_data is not None:
6874
grid_m = routing_data.n_blocks(m, block_m)
6975
else:

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_amd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, microscaling_ctx):
77
lhs_width = lhs_dtype.itemsize
8-
rhs_width = rhs_dtype.itemsize if microscaling_ctx.weight_scale is None else 0.5
8+
rhs_width = rhs_dtype.itemsize if rhs_dtype != torch.uint8 else 0.5
99

1010
# block_n:
1111
n_cu = torch.cuda.get_device_properties(0).multi_processor_count
@@ -27,6 +27,6 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, microsc
2727

2828
# TODO: block_k = 128 seems to work better for now.
2929
# perhaps due to increased number of k loops to pipeline
30-
if microscaling_ctx.weight_scale is not None:
30+
if microscaling_ctx.weight_scale is not None and get_cdna_version() != 4:
3131
block_k = 128
3232
return block_n, block_k

python/triton_kernels/triton_kernels/routing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .routing_details._routing_compute import _routing_clear_bitmatrix
88
from .routing_details._expt_data import _expt_data_memset
99
from .routing_details._expt_data import _expt_data_compute
10+
from .target_info import is_hip
1011

1112

1213
@dataclass
@@ -202,7 +203,7 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
202203
cdiv = triton.cdiv
203204
# block_ms are all powers-of-two between 16 and 128 (inclusive)
# NOTE(review): with block_m_log2_end raised to 9 on HIP below, the upper
# bound becomes 256 on HIP — this comment is stale after this change.
204205
block_m_log2_start = 4
205-
block_m_log2_end = 8
206+
block_m_log2_end = 9 if is_hip() else 8
206207
block_m_num = block_m_log2_end - block_m_log2_start
207208
if n_gates <= n_expts_tot:
208209
max_n_tiles = n_gates

0 commit comments

Comments (0)