Commit aff4b7a

yongjik and ThomasRaoux authored

[KERNEL] Fix _p_matmul_ogs when x is transposed. (#8156)

Also added checks to disable the persistent kernel when y is transposed.

Co-authored-by: Thomas Raoux <[email protected]>
1 parent e5e3dc0 commit aff4b7a

File tree: 6 files changed, +59 -12 lines changed
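For context on the terminology used throughout the diff: "x is transposed" means the activation tensor is column major, i.e. its innermost dimension is not contiguous. A minimal sketch of how such a tensor arises and how it is detected from strides, assuming nothing beyond stock PyTorch (the asserts mirror the `x_transpose = x.stride(-1) != 1` check added below):

    import torch

    # A row-major tensor has stride 1 along its last dimension; transposing a
    # contiguous buffer yields a view whose last-dim stride is the leading dim.
    x_row_major = torch.randn(64, 32)                    # stride = (32, 1)
    x_col_major = torch.randn(32, 64).transpose(-1, -2)  # shape (64, 32), stride = (1, 64)

    assert x_row_major.stride(-1) == 1      # "not transposed"
    assert x_col_major.stride(-1) != 1      # "transposed" in this commit's sense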

python/triton_kernels/tests/test_matmul.py
Lines changed: 29 additions & 2 deletions

@@ -159,6 +159,9 @@ class Case:
     split_k: int = 1
     hbm_swizzling: bool = False
     epilogue_subtile: Union[int, None] = None
+    x_transpose: bool = False
+    w_transpose: bool = False
+    y_transpose: bool = False


 @pytest.mark.parametrize(
@@ -252,6 +255,13 @@ class Case:
         Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
         Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
         Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
+    ] + [
+        Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
+        for mode in ("batched", "ragged")
+        for dtype in ("float16", "float8_e5m2")
+        for x_transpose in (False, True)
+        for w_transpose in (False, True)
+        for y_transpose in (False, True)
     ]
     ],
 )
@@ -268,6 +278,7 @@ class Case:
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
     if is_cuda():
@@ -372,6 +383,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                        has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

+    if x_transpose:
+        x_tri = x_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if w_transpose:
+        w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if y_transpose:
+        n_rows = m if gindx is None else gindx.dst_indx.shape[0]
+        yT_shape = (n_expts_tot, n, n_rows) if mode == "batched" else (n, n_rows)
+        y_tri_in = torch.empty(yT_shape, dtype=act_dtype, device=device).transpose(-1, -2)
+    else:
+        y_tri_in = None
+
     if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
@@ -422,9 +444,14 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     # triton
     try:
-        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt,
+                           gammas=gs1_ref, epilogue=epilogue, y=y_tri_in)
     except (opt_flags.InapplicableConstraint, NotImplementedError):
         pytest.skip("inapplicable opt_flags constraint")
+    if y_tri_in is not None:
+        assert tri_y.data_ptr() == y_tri_in.data_ptr()
+        assert tri_y.shape == y_tri_in.shape
+        assert tri_y.stride() == y_tri_in.stride()
     # If split_k > 1, then the intermediate tensor is fp32.
     sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
@@ -534,7 +561,7 @@ def test_set_idle_sms():
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1, 1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1, False)
     assert flags.idle_sms == num_idle_sms
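The transpose/contiguous/transpose chain used by the test is the standard trick for building a column-major operand that keeps its logical shape. A small sketch of what it does, with assumed shapes purely for illustration:

    import torch

    x = torch.randn(320, 400)
    x_t = x.detach().transpose(-1, -2).contiguous().transpose(-1, -2)

    assert torch.equal(x, x_t)        # same values, same logical shape
    assert x.stride() == (400, 1)     # row major
    assert x_t.stride() == (1, 320)   # column major: only the strides differ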

python/triton_kernels/triton_kernels/matmul_ogs.py
Lines changed: 14 additions & 5 deletions

@@ -177,6 +177,8 @@ def apply_allocation(allocation: MatmulAllocation, output):
     if output is None:
         output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
     else:
+        if output.ndim == 2:
+            output = output[None, :, :]
         assert output.shape == allocation.output[0]
     ret["output"] = output[None, :, :]
     ret["scratchpad"] = {
@@ -350,6 +352,7 @@ def matmul_ogs(x, w, bias,
         x_scale = Tensor(x_scale)
     if not isinstance(x, Tensor):
         x = Tensor(x, dtype=x.dtype)
+    x_transpose = x.stride(-1) != 1
     # determine shapes
     has_gather = gather_indx is not None
     has_scatter = scatter_indx is not None
@@ -362,14 +365,20 @@ def matmul_ogs(x, w, bias,
         assert x.shape[0] == w.shape[0]
     # compute optimization flags
     out_dtype = precision_config.out_dtype or x.dtype
-    can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
-                  w.numel() > 0 and w.storage.is_tma_compliant() and \
-                  (w_scale is None or w_scale.storage.is_tma_compliant())
+    can_use_tma = (
+        x.numel() > 0 and x.storage.is_tma_compliant() and
+        w.numel() > 0 and w.storage.is_tma_compliant() and
+        (w_scale is None or w_scale.storage.is_tma_compliant()) and
+        (not is_ragged or x.stride(-1) == 1) and
+        # Currently we don't support tma if y is column major; may revisit later if this becomes an issue.
+        (y is None or y.stride(-1) == 1)
+    )
     # hopper w/ mxfp4 doesn't support TMA
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
-                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter,
+                               epilogue.effective_itemsize, x_transpose,
     )
     if not can_use_fused_scatter and opt_flags.fused_scatter:
         raise InapplicableConstraint("Fused scatter is not supported")
@@ -469,7 +478,7 @@ def matmul_ogs(x, w, bias,
         y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
         *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
         *out_matmul_scale_strides[-4:],
-        x_tensor_or_tma, x_storage.data, *x_strides,
+        x_tensor_or_tma, x_storage.data, *x_strides, x_transpose,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
         w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
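The widened `can_use_tma` expression adds two stride-based guards: a ragged, column-major x and a caller-provided, column-major y both fall back to the non-TMA (non-persistent) path. A simplified restatement of the y guard; the helper name is illustrative, not part of the library:

    import torch

    def y_allows_tma(y):
        # Mirrors the `(y is None or y.stride(-1) == 1)` clause above.
        return y is None or y.stride(-1) == 1

    y_row_major = torch.empty(400, 320)
    y_col_major = torch.empty(320, 400).transpose(-1, -2)  # shape (400, 320), stride (1, 400)

    assert y_allows_tma(None)
    assert y_allows_tma(y_row_major)
    assert not y_allows_tma(y_col_major)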

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def _matmul_ogs(
     Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
     YExpectedScale, YActualScale, YChecksumScale,
     stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
-    X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+    X, XPtr, stride_x_z, stride_x_m, stride_x_k, X_TRANSPOSE: tl.constexpr,
     XScale,
     XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
     W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
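`X_TRANSPOSE` is declared `tl.constexpr`, so, like the existing `W_TRANSPOSE` flag, each value compiles to its own specialized kernel instead of branching at run time. A toy Triton kernel sketching that mechanism (illustrative only, unrelated to `_matmul_ogs`; assumes a CUDA device is available):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def negate_kernel(X, Y, N, FLIP: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(X + offs, mask=mask)
        if FLIP:  # resolved at compile time, per specialization
            x = -x
        tl.store(Y + offs, x, mask=mask)

    x = torch.randn(1024, device="cuda")
    y = torch.empty_like(x)
    negate_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), FLIP=True, BLOCK=256)
    assert torch.equal(y, -x)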

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
Lines changed: 7 additions & 3 deletions

@@ -82,7 +82,7 @@ def _p_matmul_ogs(
     Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
     YExpectedScale, YActualScale, YChecksumScale,
     stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
-    X, XPtr, stride_x_z, stride_x_m, stride_x_k,
+    X, XPtr, stride_x_z, stride_x_m, stride_x_k, X_TRANSPOSE: tl.constexpr,
     XScale,
     XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
     W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
@@ -287,8 +287,12 @@ def _p_matmul_ogs(
         if USE_GATHER_TMA:
             x = X.gather(offs_x_m, off_k)
         elif X_TMA_MODE == "dense":
-            x = X.load([start_z, start_m + off_m, off_k])
-            x = x.reshape(BLOCK_M, BLOCK_K)
+            if X_TRANSPOSE:
+                x = X.load([start_z, off_k, start_m + off_m])
+                x = x.reshape(BLOCK_K, BLOCK_M).T
+            else:
+                x = X.load([start_z, start_m + off_m, off_k])
+                x = x.reshape(BLOCK_M, BLOCK_K)
         elif X_TMA_MODE == "ragged":
             x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
             x = x.reshape(BLOCK_M, BLOCK_K)
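This hunk is the actual fix: when x is column major, the kernel loads a [BLOCK_K, BLOCK_M] tile with the last two coordinates swapped and then transposes it back in registers, which yields the same [BLOCK_M, BLOCK_K] tile the dense path produces. A plain-PyTorch sketch of the equivalence this relies on (illustrative shapes and offsets, not the kernel itself):

    import torch

    M, K, BLOCK_M, BLOCK_K = 16, 32, 8, 8
    off_m, off_k = 8, 16
    x = torch.arange(M * K, dtype=torch.float32).reshape(M, K)

    # Row-major path: load the [BLOCK_M, BLOCK_K] tile directly.
    tile_dense = x[off_m:off_m + BLOCK_M, off_k:off_k + BLOCK_K]
    # Transposed path: load the swapped-coordinate tile from x^T, then transpose.
    tile_via_T = x.t()[off_k:off_k + BLOCK_K, off_m:off_m + BLOCK_M].t()

    assert torch.equal(tile_dense, tile_via_T)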

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
Lines changed: 5 additions & 1 deletion

@@ -44,6 +44,7 @@ def make_default_opt_flags_amd(
     can_use_fused_scatter,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
+    x_transpose,
     constraints,
 ):
     constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
@@ -143,6 +144,7 @@ def make_default_opt_flags_nvidia(
     can_use_fused_scatter,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
+    x_transpose,
     constraints,
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
@@ -207,6 +209,7 @@ def make_default_opt_flags_nvidia(
         out_dtype,
         lhs_dtype,
         rhs_dtype,
+        x_transpose,
     )

     if constraints.get("epilogue_subtile", None) is not None:
@@ -286,6 +289,7 @@ def make_opt_flags(
     can_use_persistent_tma,
     can_use_fused_scatter,
     epilogue_effective_itemsize,
+    x_transpose,
 ):
     if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
@@ -297,7 +301,7 @@ def make_opt_flags(
         return _opt_flags
     args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
             routing_data, can_use_persistent_tma, can_use_fused_scatter,
-            enforce_bitwise_invariance, epilogue_effective_itemsize,
+            enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose,
            _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "hip":
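`x_transpose` is threaded positionally through `make_opt_flags` into both backend-specific builders (and, on NVIDIA, into `compute_num_stages`), which is why the existing caller in the test gained one more trailing argument. A simplified sketch of this plumbing pattern; the names and signatures below are illustrative, not the module's real API:

    def make_flags_amd(lhs_dtype, rhs_dtype, x_transpose):
        return {"backend": "hip", "x_transpose": x_transpose}

    def make_flags_nvidia(lhs_dtype, rhs_dtype, x_transpose):
        return {"backend": "cuda", "x_transpose": x_transpose}

    def make_flags(backend, lhs_dtype, rhs_dtype, x_transpose):
        # Shared arguments, including the new flag, are packed once and forwarded
        # positionally, so every backend builder must accept the extra parameter.
        args = [lhs_dtype, rhs_dtype, x_transpose]
        builder = make_flags_amd if backend == "hip" else make_flags_nvidia
        return builder(*args)

    assert make_flags("cuda", "float16", "float16", True)["x_transpose"] is True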

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
Lines changed: 3 additions & 0 deletions

@@ -72,6 +72,7 @@ def compute_num_stages(
     out_dtype,
     lhs_dtype,
     rhs_dtype,
+    x_transpose,
     epilogue_subtile,
     epilogue_effective_itemsize,
 ):
@@ -103,6 +104,8 @@ def compute_num_stages(
         # pipelined layout conversion before store of the accumulator
         # note: layout conversion has some padding
         smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+    if x_transpose:
+        smem_capacity -= block_m * block_k * lhs_dtype.itemsize
     if precision_config.weight_scale is not None:
         # mx scales
         stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
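When x is transposed, the persistent kernel reserves additional shared memory for one BLOCK_M x BLOCK_K tile of the lhs dtype, shrinking the budget that `compute_num_stages` can spend on pipeline stages. A quick worked example with assumed tile sizes (illustrative numbers, not defaults):

    block_m, block_k = 128, 64
    lhs_itemsize = 2                               # e.g. float16
    extra_smem = block_m * block_k * lhs_itemsize  # 16384 bytes = 16 KiB
    print(f"extra smem reserved when x is transposed: {extra_smem} bytes")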
