
Commit f21b341

Revert "[mxfp] remove col-major assert for mx weight (#8249)"
This reverts commit 60605d8.
1 parent e52429a · commit f21b341

2 files changed: +11, -73 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 10 additions & 70 deletions
@@ -197,7 +197,6 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
-    colmajor_mxfp_weight: bool = True


 @pytest.mark.parametrize(
@@ -270,7 +269,6 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -317,7 +315,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -465,72 +463,14 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri_orig = w_tri
-        if colmajor_mxfp_weight:
-            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-            w_scale_tri = wrap_torch_tensor(w_scale_tri)
-            # convert layouts
-            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
-        else:
-            if torch.cuda.get_device_capability()[0] < 10:
-                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
-            if block_m == 16:
-                pytest.skip("PassManager::run failed from Triton compiler")
-            # TODO: swizzling for rowmajor
-
-            # A typical use case is we already quantized col-major weight,
-            # and we want matmul with its transposed row-major weight w/o
-            # requantization.
-
-            # put abs_max of each 32x32 block to diagonal so scales of transposed agree
-            w_ndim = w_tri.ndim
-            if w_ndim == 2:
-                w_tri = w_tri.unsqueeze(0)
-            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
-            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
-                i_end = min(i+BLOCK_SIZE, w_tri.shape[1])
-                j_end = min(j+BLOCK_SIZE, w_tri.shape[2])
-                block = w_tri[e, i:i_end, j:j_end]
-                m_abs = block.abs().max()
-                i_len = i_end - i
-                j_len = j_end - j
-                min_len = min(i_len, j_len)
-                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
-                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
-                if j_len > i_len:
-                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
-                elif i_len > j_len:
-                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
-            if w_ndim == 2:
-                w_tri = w_tri.squeeze(0)
-
-            # matmul with rowmajor weight expects scale is separately
-            # constructed (not much additional memory needed).
-            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-            # reuse quantized value from colmajor
-            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
-            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
-            w_tri = w_tri_rowmajor.data.mT
-
-            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
-                x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
-                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
-
-            # check if generated scale is transpose-invariant as intended construction
-            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
-            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
-            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
-            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
-            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
-            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
-            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
-            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
-            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
-
+        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+        w_scale_tri = wrap_torch_tensor(w_scale_tri)
+        # convert layouts
+        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -539,7 +479,7 @@ def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
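
For context on the helpers the surviving col-major path keeps calling: downcast_to_mxfp / upcast_from_mxfp group elements into blocks of MXFP_BLOCK_SIZE (32) along the mx axis and give each block one shared power-of-two scale. The sketch below is a simplified plain-PyTorch illustration of that round trip, not the library's implementation: the function name toy_mxfp_roundtrip, the exact power-of-two rounding, and the elided element quantization are assumptions for illustration, and the real helpers additionally narrow elements to mxfp4/mxfp8 and return the scales as a separate tensor.

import torch

MXFP_BLOCK_SIZE = 32  # block length along the mx axis, matching the constant used in the test

def toy_mxfp_roundtrip(w: torch.Tensor, axis: int) -> torch.Tensor:
    # Group `axis` into contiguous blocks of 32 and give each block one shared
    # power-of-two scale; a real mxfp downcast would also round the scaled values
    # to fp4/fp8 elements and keep the scales in a separate tensor.
    w = w.movedim(axis, -1).float()
    assert w.shape[-1] % MXFP_BLOCK_SIZE == 0
    blocks = w.unflatten(-1, (-1, MXFP_BLOCK_SIZE))
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-30).log2().ceil().exp2()
    dequant = (blocks / scale) * scale  # the element quantization step is elided in this toy version
    return dequant.flatten(-2).movedim(-1, axis)

w = torch.randn(704, 800, dtype=torch.bfloat16)
w_ref_toy = toy_mxfp_roundtrip(w, axis=0).to(torch.bfloat16)  # axis=0 plays the role of mx_axis (along K)

Under this blocking scheme, the reverted row-major branch above constructed weights whose 32x32 blocks carry their abs-max on the diagonal, so that per-block scales computed for the weight and for its transpose agree, which is what its final three asserts checked.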

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 1 addition & 3 deletions
@@ -17,7 +17,6 @@
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._reduce_grouped import _reduce_grouped
 from .numerics_details.mxfp import MXFP_BLOCK_SIZE
-from .tensor_details.layout_details.strided import StridedLayout
 from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
 from .specialize import specialize
 from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@@ -442,13 +441,12 @@ def matmul_ogs(x, w, bias,
     w_scale = precision_config.weight_scale
     w_has_mx = w_scale is not None
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+    if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
-    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
-        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:
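
The reinstated check is purely a stride assertion: matmul_ogs once again requires the mx-scaled weight's K dimension to be the contiguous one (w.stride(-2) == 1) on every device, whereas the removed conditional only enforced it pre-Blackwell or for swizzled layouts. A minimal plain-PyTorch sketch of what that means for a caller, assuming a hypothetical (E, K, N) weight; w_colmajor here is just an illustration, not a helper from triton_kernels:

import torch

# Row-major weight: the last (N) dimension is contiguous.
w_rowmajor = torch.randn(8, 800, 704)        # (E, K, N), strides (800*704, 704, 1)
assert w_rowmajor.stride(-1) == 1

# "Column-major" in the sense of the assert: the K dimension is contiguous.
w_colmajor = w_rowmajor.mT.contiguous().mT   # same logical (E, K, N) values, strides (704*800, 1, 800)
assert w_colmajor.stride(-2) == 1            # this is what the reinstated check requires for mxfp weights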
