
Commit 19eef7c

[Bench][AMD] Fix HIP capability checks in MoE kernel (#7545)
Fix minor syntax issues on AMD side.
Parent commit: da3ab2a

File tree: 2 files changed (+10, −8 lines)

python/triton_kernels/bench/bench_mlp.py

Lines changed: 8 additions & 7 deletions
@@ -10,7 +10,7 @@
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
-from triton_kernels.target_info import is_hip, get_cdna_version
+from triton_kernels.target_info import is_hip, get_cdna_version, is_cuda
 from triton_kernels.tensor import convert_layout
 from triton_kernels.tensor_details.layout import StridedLayout, BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
@@ -101,14 +101,15 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     optg = dict()
     opt1 = dict()
     opt2 = dict()
-    if w_dtype == "mx4" and not is_hip():
+    if w_dtype == "mx4":
         value_layout = StridedLayout
         scale_layout = StridedLayout
-        if torch.cuda.get_device_capability()[0] == 9:
-            value_layout = HopperMXValueLayout
-            scale_layout = HopperMXScaleLayout
-        if torch.cuda.get_device_capability()[0] == 10:
-            scale_layout = BlackwellMXScaleLayout
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] == 9:
+                value_layout = HopperMXValueLayout
+                scale_layout = HopperMXScaleLayout
+            if torch.cuda.get_device_capability()[0] == 10:
+                scale_layout = BlackwellMXScaleLayout
         opt1 = {"value_layout": value_layout, "scale_layout": scale_layout}
         opt2 = deepcopy(opt1)
     wg, wg_flex, wg_scale = quantize(wg, "bf16", dev, **optg)
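
The substance of this change: the "mx4" path is no longer skipped on HIP, and the CUDA-specific capability queries are guarded by is_cuda(), so non-CUDA backends keep the portable StridedLayout defaults. A minimal standalone sketch of that selection pattern, using the helpers and layout classes imported in the diff above (the function name pick_mx4_layouts is illustrative, not part of the benchmark):

import torch
from triton_kernels.target_info import is_cuda
from triton_kernels.tensor_details.layout import (
    StridedLayout, BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout,
)

def pick_mx4_layouts():
    # Portable defaults, used on HIP and any other non-CUDA backend.
    value_layout, scale_layout = StridedLayout, StridedLayout
    if is_cuda():
        # Only query the compute capability once the backend is known to be CUDA.
        major = torch.cuda.get_device_capability()[0]
        if major == 9:    # Hopper
            value_layout, scale_layout = HopperMXValueLayout, HopperMXScaleLayout
        if major == 10:   # Blackwell
            scale_layout = BlackwellMXScaleLayout
    return {"value_layout": value_layout, "scale_layout": scale_layout}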

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.target_info import is_cuda
 # details
 from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
 from .matmul_ogs_details._matmul_ogs import _matmul_ogs
@@ -384,7 +385,7 @@ def matmul_ogs(x, w, bias,
     # unpack scales
     w_scale = precision_config.weight_scale
     has_mx = w_scale is not None
-    is_hopper_fp8 = not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+    is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
     if has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
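
The matmul_ogs change follows the same principle: the column-major requirement for 8-bit weights is tied to pre-Blackwell NVIDIA hardware, so the predicate must first confirm the backend is CUDA. Without the is_cuda() guard, the capability check has no meaning on HIP, and FP8 weights there could be flagged as "Hopper FP8". A small sketch of the corrected predicate, assuming only the is_cuda and cuda_capability_geq helpers shown above (the function name and the bit-width argument are illustrative):

from triton_kernels import target_info
from triton_kernels.target_info import is_cuda

def needs_hopper_fp8_layout(w_dtype_bits: int) -> bool:
    # Short-circuit on non-CUDA backends first: on HIP the CUDA capability
    # query does not describe an NVIDIA architecture, so it must never decide
    # the weight-layout requirement on its own.
    return is_cuda() and not target_info.cuda_capability_geq(10, 0) and w_dtype_bits == 8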
