
Commit de4376e

[triton_kernels][matmul] support inputs with 0 elements (#7808)
1 parent c850534

File tree: 3 files changed (+19 −3 lines)

python/triton_kernels/tests/test_matmul.py

Lines changed: 9 additions & 2 deletions
@@ -161,6 +161,13 @@ class Case:
     ", ".join(f.name for f in fields(Case)),
     [
         tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+            # Zero-sized args:
+            Case(0, 5, 7, "ragged", "float16", "float16"),
+            Case(5, 0, 7, "ragged", "float16", "float16"),
+            Case(5, 7, 0, "ragged", "float16", "float16"),
+            Case(0, 5, 7, "batched", "float16", "float16"),
+            Case(5, 0, 7, "batched", "float16", "float16"),
+            Case(5, 7, 0, "batched", "float16", "float16"),
             # Non-mx types:
             Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
             Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
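
The new cases exercise each of m, n, and k being zero, in both ragged and batched modes. As a point of reference for the semantics involved, here is a minimal sketch of zero-sized matmul behavior in plain PyTorch (an illustration only, not the triton_kernels code path):

import torch

# Contracting over k == 0 yields an all-zero (m, n) result;
# m == 0 or n == 0 yields an empty output of the right shape.
x = torch.randn(5, 0)  # m=5, k=0
w = torch.randn(0, 7)  # k=0, n=7
y = x @ w
assert y.shape == (5, 7) and (y == 0).all()
assert (torch.randn(0, 4) @ torch.randn(4, 7)).shape == (0, 7)  # m == 0
assert (torch.randn(5, 4) @ torch.randn(4, 0)).shape == (5, 0)  # n == 0
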
@@ -301,7 +308,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     # launch metadata for batched / mx types may not work yet.
-    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) and fused_scatter
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) and fused_scatter and m*n*k != 0

     torch.manual_seed(0)

@@ -349,7 +356,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                                  has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

-    if w_tri.shape[0] == 1:
+    if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
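
The extra mode != "batched" check keeps the weight 3-D whenever the mode is batched: there the leading dimension is a real batch dimension even when it equals 1, and squeezing it would change the operation under test. An illustration in plain PyTorch, with torch.bmm standing in for the batched path:

import torch

# A batch of size 1 must stay 3-D for a batched matmul; squeezing the
# leading dim would turn w into a single 2-D weight and break bmm.
x = torch.randn(1, 3, 4)  # (batch=1, M, K)
w = torch.randn(1, 4, 8)  # (batch=1, K, N)
y = torch.bmm(x, w)       # bmm requires both operands to be 3-D
assert y.shape == (1, 3, 8)
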

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 7 additions & 1 deletion
@@ -6,6 +6,7 @@
 import torch
 import triton
 from enum import Enum, auto
+import math
 # utilities
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
@@ -458,6 +459,11 @@ def matmul_ogs(x, w, bias,
         opt_flags, preprocessing_features, postprocessing_features
     )
     memory = apply_allocation(allocation, y)
+    if batch_size * M * N == 0:
+        ret = memory["output"].squeeze(0)
+        if not is_input_batched:
+            ret = ret.squeeze(0)
+        return ret
     # TMA descriptors require a global memory allocation
     if opt_flags.is_persistent:
         triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
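
When batch_size * M * N is zero there is nothing to compute, so the function hands back the freshly allocated (empty) output instead of launching a kernel. A minimal sketch of the pattern with simplified names (a hypothetical helper, not the actual matmul_ogs signature):

import torch

def maybe_return_empty(output: torch.Tensor, batch_size: int, M: int, N: int,
                       is_input_batched: bool):
    # Mirrors the diff: squeeze the leading dim of the allocated buffer,
    # and the batch dim too when the caller passed unbatched inputs.
    if batch_size * M * N == 0:
        ret = output.squeeze(0)
        if not is_input_batched:
            ret = ret.squeeze(0)
        return ret
    return None  # non-empty output: proceed to the kernel launch
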
@@ -509,7 +515,7 @@ def matmul_ogs(x, w, bias,
     has_scatter = writeback_idxs is not None
     has_gather_tma = has_gather and target_info.has_tma_gather()
     has_scatter_tma = has_scatter and target_info.has_tma_gather()
-    y = wrap_torch_tensor(out0.view(-1, out0.shape[-1]) if has_scatter else out0.view(-1, *out0.shape[-2:]))
+    y = wrap_torch_tensor(out0.view(math.prod(out0.shape[:-1]), out0.shape[-1]) if has_scatter else out0.view(math.prod(out0.shape[:-2]), *out0.shape[-2:]))
     x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
     w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
     y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
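
The view(-1, ...) to view(math.prod(...), ...) change avoids PyTorch's size inference, which fails on zero-element tensors whenever the explicitly given dimensions multiply to zero: any value for -1 would satisfy numel() == 0, so PyTorch refuses to guess. Computing the leading size explicitly sidesteps this. A small demonstration, assuming nothing beyond stock PyTorch:

import math
import torch

t = torch.empty(2, 5, 0)  # zero elements, trailing dim of size 0
try:
    t.view(-1, 0)         # -1 is ambiguous when the remaining dims give 0
except RuntimeError as e:
    print(e)              # "cannot reshape tensor of 0 elements ..."
print(t.view(math.prod(t.shape[:-1]), 0).shape)  # torch.Size([10, 0])
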

python/triton_kernels/triton_kernels/testing.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=T
         return
     ref = ref_as_type

+    if ref.numel() == 0:
+        return
+
     if maxtol is None:
         maxtol = 2e-2
     if rmstol is None:
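
With zero elements there is nothing to compare, so assert_close now returns success before computing tolerances; stock torch.testing treats same-shaped empty tensors the same way:

import torch

# Empty tensors of matching shape compare as trivially close;
# this raises nothing.
torch.testing.assert_close(torch.empty(0, 7), torch.empty(0, 7))
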
