Commit e52429a

Merge commit '0173f7524d8cfc9a5b4b52dec0010eaedef14526'
2 parents 609e327 + 0173f75 commit e52429a

File tree: 19 files changed (+391 -29 lines changed)


.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ jobs:
   integration-tests-amd:
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 45
+    continue-on-error: ${{ matrix.runner[1] == 'gfx90a' }}
     strategy:
       matrix:
         runner: ${{ fromJson(inputs.matrix) }}

python/test/unit/language/test_core.py

Lines changed: 2 additions & 0 deletions
@@ -1659,6 +1659,8 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 @pytest.mark.parametrize("dtype_str", ["int32", "int64"])
 def test_atomic_cas(sem, num_ctas, dtype_str, device):
+    if is_hip_cdna2():
+        pytest.skip("Disabled due to being flaky on CDNA2")
     # 1. make sure that atomic_cas changes the original value (Lock)
     @triton.jit
     def change_value(Lock, triton_dtype: tl.constexpr):

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
         "accumulator tensor's layout must be (16, 16, 128)"

     # TODO: Add more formats
-    assert a_format.value in {"e2m1"}, f"Unsupported lhs_format: {a_format.value}"
-    assert b_format.value in {"e2m1"}, f"Unsupported rhs_format: {b_format.value}"
+    assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
+    assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"

     assert a_scale is not None and b_scale is not None, "Scales must not be None"

python/triton_kernels/tests/test_matmul.py

Lines changed: 70 additions & 10 deletions
@@ -197,6 +197,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True


 @pytest.mark.parametrize(
@@ -269,6 +270,7 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -315,7 +317,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -463,14 +465,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
     # downcast to mxfp
-    w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-    w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-    w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-    w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-    w_scale_tri = wrap_torch_tensor(w_scale_tri)
-    # convert layouts
-    w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-    w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+    w_tri_orig = w_tri
+    if colmajor_mxfp_weight:
+        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+        w_scale_tri = wrap_torch_tensor(w_scale_tri)
+        # convert layouts
+        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+    else:
+        if torch.cuda.get_device_capability()[0] < 10:
+            pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+        if block_m == 16:
+            pytest.skip("PassManager::run failed from Triton compiler")
+        # TODO: swizzling for rowmajor
+
+        # A typical use case is we already quantized col-major weight,
+        # and we want matmul with its transposed row-major weight w/o
+        # requantization.
+
+        # put abs_max of each 32x32 block to diagonal so scales of transposed agree
+        w_ndim = w_tri.ndim
+        if w_ndim == 2:
+            w_tri = w_tri.unsqueeze(0)
+        BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+        for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+            i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+            j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+            block = w_tri[e, i:i_end, j:j_end]
+            m_abs = block.abs().max()
+            i_len = i_end - i
+            j_len = j_end - j
+            min_len = min(i_len, j_len)
+            signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+            block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+            if j_len > i_len:
+                block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+            elif i_len > j_len:
+                block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+        if w_ndim == 2:
+            w_tri = w_tri.squeeze(0)
+
+        # matmul with rowmajor weight expects scale is separately
+        # constructed (not much additional memory needed).
+        _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        # reuse quantized value from colmajor
+        w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+        w_tri = w_tri_rowmajor.data.mT
+
+        def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+            x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
+            return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)

+        # check if generated scale is transpose-invariant as intended construction
+        # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+        w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+        w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+        # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+        w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+        w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+        assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+        assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+        assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
     precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -479,7 +539,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
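
Note on the diagonal trick above: an MXFP scale for a 32-element group depends only on that group's absolute maximum, so if every row and every column of a 32x32 block contains the block-wide abs-max, quantizing along K (col-major) and along N (transposed row-major) yields the same per-block scale. A minimal sketch assuming only PyTorch; group_scale below is a simplified illustrative stand-in, not the library's downcast_to_mxfp:

import torch

torch.manual_seed(0)
BLOCK = 32
w = torch.randn(BLOCK, BLOCK)
m_abs = w.abs().max()
signs = torch.randint(0, 2, (BLOCK,)) * 2 - 1
w.diagonal()[:] = signs * m_abs  # every row and every column now holds +/- the block abs-max

def group_scale(x, axis):
    # simplified shared-exponent scale for a 32-wide group along `axis`
    return x.abs().amax(dim=axis).log2().floor()

# per-column and per-row scales agree, which is what the test asserts
assert torch.equal(group_scale(w, axis=0), group_scale(w, axis=1))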

python/triton_kernels/tests/test_mxfp.py

Lines changed: 16 additions & 0 deletions
@@ -45,6 +45,22 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
     assert_equal(dequant_torch, dequant)


+@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
+@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
+def test_mxfp_extreme_values(src_dtype, dst_dtype, device):
+    if "float8" in src_dtype and (is_cuda() and torch.cuda.get_device_capability()[0] < 9):
+        pytest.skip("Float8 not tested on A100")
+    src_dtype = dtype_str_to_torch(src_dtype)
+    dst_dtype = dtype_str_to_torch(dst_dtype)
+    BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38
+    x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device)
+    xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1)
+    xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1)
+    xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1)
+    assert_equal(xdq_ref, xdq)
+    assert not xdq.isinf().any()
+
+
 @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._reduce_grouped import _reduce_grouped
 from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .tensor_details.layout_details.strided import StridedLayout
 from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
 from .specialize import specialize
 from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@@ -441,12 +442,13 @@ def matmul_ogs(x, w, bias,
     w_scale = precision_config.weight_scale
     w_has_mx = w_scale is not None
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
-    if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
+    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:
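
Note: the relaxed check above only requires column-major ("K-major") mxfp weights when the device is pre-Blackwell (capability < 10) or the weight storage uses a non-strided (swizzled) layout; otherwise a row-major strided weight is now accepted. A minimal sketch of the stride test it relies on, assuming plain PyTorch only:

import torch

w_rowmajor = torch.empty(128, 256, dtype=torch.uint8)     # stride(-2) == 256
w_colmajor = torch.empty(256, 128, dtype=torch.uint8).mT  # transposed view, stride(-2) == 1
assert w_rowmajor.stride(-2) != 1
assert w_colmajor.stride(-2) == 1  # this is what the column-major assert checks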

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 11 additions & 0 deletions
@@ -297,6 +297,17 @@ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dty
     padded_tensor = padded_tensor.view(*new_shape)
     dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
     out_padded = padded_tensor * dq_scale_padded
+    # Need to clamp since due to rounding, we can have overflow that was within
+    # the range before quantization.
+    # e.g., 3.3895e+38 -> log2(3.3895e+38 / max_fp8e4m3=448) ~= 119.17 -> round
+    # up to 120 + exp_bias=127 -> scale=247
+    # 3.3895e+38 / 2**120 ~= 254.9976 -> round to 256 in fp8e4m3fn
+    # Dequantization: 256 * 2**120 > 3.4e38 overflowing 3.38953139e38
+    finfo = torch.finfo(target_dtype)
+    out_padded = (padded_tensor * dq_scale_padded).clamp(finfo.min, finfo.max)
+    if tensor.dtype == torch.float8_e5m2:
+        # fp8e5m2 can have inf and we want to preserve so separately handle
+        out_padded = out_padded.where(~padded_tensor.isinf(), padded_tensor.to(target_dtype))

     # Flatten back and remove the padded tail
     out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
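
The overflow arithmetic in the new comment can be reproduced directly. A standalone sketch assuming only PyTorch and the standard format constants (448 is the e4m3fn finite max; the shared exponent follows the comment's "round up"); it is illustrative, not the library's downcast path:

import math
import torch

x = 3.3895e38                                         # close to the bfloat16 finite max
max_e4m3 = 448.0
exp = math.ceil(math.log2(x / max_e4m3))              # -> 120 (shared exponent)
q = torch.tensor(x / 2**exp).to(torch.float8_e4m3fn)  # 254.99... rounds up to 256
deq = q.double() * 2.0**exp                           # 256 * 2**120 = 2**128 ~ 3.4028e38
print(float(deq) > torch.finfo(torch.bfloat16).max)   # True: exceeds the bf16 finite max
print(deq.to(torch.bfloat16))                         # inf without the clamp added above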

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py

Lines changed: 10 additions & 0 deletions
@@ -119,6 +119,16 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     scale = scale.reshape(dst_scale.shape)

     out_tensor = dst_tensor * dst_scale
+    if dst_dtype == tl.float32:
+        max_fin = 3.4028234663852886e+38
+    elif dst_dtype == tl.bfloat16:
+        max_fin = 3.3895313892515355e+38
+    else:
+        tl.static_assert(dst_dtype == tl.float16)
+        max_fin = 65504
+    # TODO: handle infinity same as upcast_from_mxfp_torch together with the
+    # above FIXME
+    out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
     # Correct any NaNs encoded via the scale.
     out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
     out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
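
The hard-coded max-finite constants above can be cross-checked against torch.finfo on the host side; a quick sanity sketch, assuming plain PyTorch:

import torch

assert torch.finfo(torch.float32).max == 3.4028234663852886e+38
assert torch.finfo(torch.bfloat16).max == 3.3895313892515355e+38
assert torch.finfo(torch.float16).max == 65504.0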

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                      %arg1: i32 {tt.divisibility = 16 : i32},
                                      %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
-    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
+    // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK-NEXT: llvm.return
     ttg.async_commit_group
     tt.return

test/Conversion/amd/ds_transpose.mlir

Lines changed: 23 additions & 0 deletions
@@ -5,6 +5,8 @@
 #mma32_scaled = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32, 64], isTransposed = true}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#padding = #ttg.padded_shared<[512:+16] {order = [0, 1], shape = [128, 64]}>
+#padding_vec1 = #ttg.padded_shared<[1:+4] {order = [0, 1], shape = [128, 64]}>
 #smem = #ttg.shared_memory

 #linear_ds_tr_tile_out = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [32, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
@@ -688,4 +690,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.store %ptr1, %a1 : tensor<64x16x!tt.ptr<bf16>, #linear_ds_tr_tile_invalid>
     tt.return
   }
+
+  // CHECK-LABEL: ds_transpose_with_padding
+  tt.func @ds_transpose_with_padding(%arg0: !ttg.memdesc<128x64xf16, #padding, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_padding_interval_too_small
+  tt.func @ds_transpose_padding_interval_too_small(%arg0: !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable>, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #padding_vec1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+
+    %ptr1 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.store %ptr1, %1 : tensor<128x64x!tt.ptr<f16>, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
 }
