Skip to content

Commit 855fe3a

Browse files
authored
[Bench][AMD] Fix matmul tests for gfx950 (#6965)
This PR is part of the effort to fix the benchmark tests on AMD gfx950 hardware. With it, `pytest -s test_matmul.py` now has no failures (some tests are skipped for now).
1 parent 442a63d commit 855fe3a

File tree

5 files changed

+65
-41
lines changed

5 files changed

+65
-41
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def quantize(w, dtype, dev, **opt):
3333
MicroscalingCtx()
3434
else:
3535
assert dtype == "mx4", f"{dtype=}"
36-
swizzle_mx_scale = opt["swizzle_mx_scale"]
37-
swizzle_mx_value = opt["swizzle_mx_value"]
36+
swizzle_mx_scale = opt.get("swizzle_mx_scale", None)
37+
swizzle_mx_value = opt.get("swizzle_mx_value", None)
3838
swizzle_axis = 2 if swizzle_mx_scale else None
3939
w = w.to(torch.bfloat16)
4040
w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis,

python/triton_kernels/tests/test_matmul.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# testing utilities
1818
from triton_kernels.testing import assert_close, compute_actual_scale
1919
# target-specific utilities
20-
from triton_kernels.target_info import is_hip
20+
from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
2121

2222
# ---------------
2323
# initialize data
@@ -75,18 +75,19 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
7575
# ---------------
7676

7777

78-
def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx(),
79-
device="cuda"):
78+
def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ctx=MicroscalingCtx(), device="cuda"):
79+
act_use_flexpoint = out_dtype.itemsize == 1
80+
weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
8081
# flexpoint
8182
make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
8283
([val0]
8384
if n_expts_tot % 2 else []), dtype=torch.float32, device=device)
8485
make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device)
85-
in_flex_data = lambda scale, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_scalar(scale)
86+
in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale)
8687
) if use_flex else InFlexData()
87-
in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_tensor(
88-
scale0, scale1)) if use_flex else InFlexData()
89-
out_flex_data = lambda scale, use_flex: OutFlexData(dtype=torch.float8_e5m2, expected_scale=make_scalar(
88+
in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=weight_dtype, scale=make_tensor(scale0, scale1)
89+
) if use_flex else InFlexData()
90+
out_flex_data = lambda scale, use_flex: OutFlexData(dtype=out_dtype, expected_scale=make_scalar(
9091
scale), actual_scale=make_scalar(0), checksum_scale=make_scalar(0)) if use_flex else OutFlexData()
9192
flex_ctx = FlexCtx(
9293
lhs_data=in_flex_data(1.25, act_use_flexpoint),
@@ -211,8 +212,11 @@ class Case:
211212
Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
212213
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
213214
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
214-
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
215215
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
216+
Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
217+
Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
218+
Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
219+
Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
216220
]
217221
],
218222
)
@@ -230,16 +234,26 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
230234
n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
231235
device, opt_flags_scope):
232236
# TODO: remove when Triton FP8 supports proper RTNE
233-
if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
234-
pytest.skip("Float8 not tested on A100")
235-
if "float8_e4m3fnuz" in weight_dtype_str and not is_hip():
236-
pytest.skip("float8_e4m3fnuz only tested on HIP platforms")
237-
if "mx" in weight_dtype_str and is_hip():
238-
pytest.skip("mxfloat* only tested on CUDA platforms")
239-
if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
240-
pytest.skip("float16 x mx not supported with cuda capability >= 10")
241-
if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
242-
pytest.skip("float8 x mx not supported with cuda capability < 10")
237+
if is_cuda():
238+
if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
239+
pytest.skip("Float8 not tested on A100")
240+
if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
241+
pytest.skip("float16 x mx not supported with cuda capability >= 10")
242+
if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
243+
pytest.skip("float8 x mx not supported with cuda capability < 10")
244+
elif is_hip():
245+
if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
246+
pytest.skip("float8 x mx only supported on CDNA4")
247+
if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
248+
pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
249+
if is_persistent:
250+
pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
251+
if split_k > 1:
252+
pytest.skip("splitK hasn't been fully tested on AMD GPU.")
253+
254+
if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
255+
pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
256+
243257
if fused_scatter and split_k > 1:
244258
pytest.skip("fused scatter scratchpad not supported with split_k")
245259
if hbm_swizzling:
@@ -284,9 +298,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
284298
weight_dtype = dtype_str_to_torch(weight_dtype_str)
285299
act_dtype = dtype_str_to_torch(act_dtype_str)
286300
act_is_float8 = act_dtype.itemsize == 1
287-
weight_is_float8 = weight_dtype.itemsize == 1
288-
precision_opt = init_precision(act_dtype, act_is_float8, weight_is_float8 and not is_mixed_input,
289-
n_expts_tot // n_expt_shards, device=device)
301+
precision_opt = init_precision(act_dtype, weight_dtype, is_mixed_input, n_expts_tot // n_expt_shards, device=device)
290302
# precision_opt.x_pad_trans_requires_flexpoint = False
291303
if mode == "ragged":
292304
m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
@@ -456,7 +468,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
456468
else:
457469
rdata = gindx = sindx = None
458470

459-
precision_opt = init_precision(act_dtype, False, False, n_expts_tot // n_expt_shards, device=device)
471+
precision_opt = init_precision(act_dtype, weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
460472
x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
461473
act_dtype, weight_dtype, False, requires_grad=False, device=device)
462474

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data,
313313
has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1
314314
if has_fused_scatter_scratchpad:
315315
M = scatter_indx.src_indx.shape[0]
316-
writeback_idxs = torch.empty((M,), dtype=torch.int32, device=x.device)
316+
writeback_idxs = torch.zeros((M,), dtype=torch.int32, device=x.device)
317317
writeback_size = writeback_idxs.shape[0]
318318
finalize_scatter_idxs = torch.zeros((M // routing_data.n_expts_act + M + 1,), dtype=torch.int32, device=x.device)
319319
BLOCK_M=256
@@ -494,12 +494,12 @@ def init_allocation(x, w, precision_config, fused_activation, routing_data, gath
494494
def apply_allocation(allocation: MatmulAllocation, output):
495495
ret = dict()
496496
if output is None:
497-
output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
497+
output = torch.zeros(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
498498
else:
499499
assert output.shape == allocation.output[0]
500500
ret["output"] = output[None, :, :]
501501
ret["scratchpad"] = {
502-
k: torch.empty(v[0], device=allocation.device, dtype=v[1])
502+
k: torch.zeros(v[0], device=allocation.device, dtype=v[1])
503503
for k, v in allocation.scratchpads.items()
504504
}
505505
return ret

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,7 @@ def make_default_opt_flags_amd(
8181
# TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
8282
if constraints.get("block_k", None) is not None:
8383
block_k = constraints["block_k"]
84-
if constraints.get("is_persistent", None) is not None:
85-
is_persistent = constraints["is_persistent"]
86-
else:
87-
is_persistent = False
84+
is_persistent = constraints.get("is_persistent", False)
8885
# split_k:
8986
if constraints.get("split_k", None) is not None:
9087
split_k = constraints["split_k"]
@@ -99,14 +96,6 @@ def make_default_opt_flags_amd(
9996
# num_warps, num_stages
10097
num_warps = 2 if (m is not None and m <= 16) else 8
10198
num_stages = 2
102-
if constraints.get("fused_scatter", None) is not None:
103-
fused_scatter = constraints["fused_scatter"]
104-
else:
105-
fused_scatter = False
106-
if constraints.get("epilogue_subtile", None) is not None:
107-
epilogue_subtile = constraints["epilogue_subtile"]
108-
else:
109-
epilogue_subtile = None
11099
# AMD-specific
111100
target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
112101
ret = OptFlags(
@@ -119,9 +108,9 @@ def make_default_opt_flags_amd(
119108
xcd_swizzle=xcd_swizzle,
120109
w_cache_modifier=w_cache_modifier,
121110
split_k=split_k,
122-
fused_scatter=fused_scatter,
111+
fused_scatter=constraints.get('fused_scatter', False),
123112
is_persistent=is_persistent,
124-
epilogue_subtile=epilogue_subtile,
113+
epilogue_subtile=constraints.get('epilogue_subtile', None),
125114
arch=None,
126115
target_kernel_kwargs=target_kernel_kwargs,
127116
)

python/triton_kernels/triton_kernels/target_info.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,35 @@
44
cached_capabilities = {}
55

66

7+
def is_cuda():
8+
if "is_cuda" not in cached_capabilities:
9+
target = triton.runtime.driver.active.get_current_target()
10+
cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
11+
return cached_capabilities["is_cuda"]
12+
13+
714
def is_hip():
815
if "is_hip" not in cached_capabilities:
916
cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
1017
return cached_capabilities["is_hip"]
1118

1219

20+
def is_hip_cdna3():
21+
if "is_hip_cdna3" not in cached_capabilities:
22+
target = triton.runtime.driver.active.get_current_target()
23+
cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
24+
and target.arch == 'gfx942')
25+
return cached_capabilities["is_hip_cdna3"]
26+
27+
28+
def is_hip_cdna4():
29+
if "is_hip_cdna4" not in cached_capabilities:
30+
target = triton.runtime.driver.active.get_current_target()
31+
cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
32+
and target.arch == 'gfx950')
33+
return cached_capabilities["is_hip_cdna4"]
34+
35+
1336
def cuda_capability_geq(major, minor=0):
1437
"""
1538
Determines whether we have compute capability >= (major, minor) and

0 commit comments

Comments
 (0)