Commit 6effefb

Merge commit 'de4376e90a3c2b5ca30ada25a50cccadeadf7f1a'
2 parents 6a95135 + de4376e commit 6effefb

File tree

11 files changed: +153 additions, -181 deletions

python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 8 additions & 5 deletions
@@ -55,9 +55,12 @@ def example_load_store_kernel(X, Y, x_off, y_off, x_size, y_size):
     store_ragged(Y, y_off, y_size, [0, 0], data)
 
 
-@pytest.mark.parametrize("write_only", [False, True])
-@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
-def test_ragged_tma(dtype, write_only):
+@pytest.mark.parametrize("dtype", [
+    "bfloat16", "float16", "float32", "float64",  # floating-point
+    "int8", "int16", "int32", "int64",  # signed integers
+    "uint8", "uint16", "uint32", "uint64"  # unsigned integers
+])
+def test_ragged_tma(dtype):
 
     if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
         pytest.skip("Test requires Hopper or Blackwell target.")
@@ -67,10 +70,10 @@ def test_ragged_tma(dtype, write_only):
 
     src = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
     ref = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
-    dst = 1.0 * ref
+    dst = ref.clone()
 
     X = create_ragged_descriptor(src, [32, 128])
-    Y = create_ragged_descriptor(dst, [32, 128], write_only=write_only)
+    Y = create_ragged_descriptor(dst, [32, 128])
 
     x_off = 42
     y_off = 51
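
A note on the expanded dtype list: torch.randn only produces floating-point tensors, which is why the test still materializes float32 data and converts with .to(dtype); for the newly added integer dtypes that conversion truncates toward zero. A standalone sketch of the pattern (dtype choices here are illustrative):

import torch

# torch.randn has no integer overload, so generate float32 first and cast afterwards;
# float -> int conversion in PyTorch truncates toward zero.
src = torch.randn((8,), dtype=torch.float32)
as_int8 = src.to(torch.int8)      # mostly 0 and +/-1 for standard-normal inputs
as_bf16 = src.to(torch.bfloat16)  # lossy rounding relative to float32
print(as_int8, as_bf16)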

python/triton/tools/ragged_tma.py

Lines changed: 5 additions & 36 deletions
@@ -4,16 +4,8 @@
 
 # fmt: off
 
-class TensorDescriptorPtr:
-    def __init__(self, data_ptr, dtype):
-        self._data_ptr = data_ptr
-        self.dtype = dtype
 
-    def data_ptr(self):
-        return self._data_ptr
-
-
-def create_ragged_descriptor(T, block_shape, ragged_dim=0, write_only=False):
+def create_ragged_descriptor(T, block_shape, ragged_dim=0):
     """
     Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
     which behaves like a concatenation (along the first axis) of subarrays
@@ -33,11 +25,7 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0, write_only=False):
         ragged_dim += rank
 
     assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
-
-    if write_only:
-        assert rank <= 4, "write-only ragged descriptors must have at most 4 dimensions"
-    else:
-        assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+    assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
 
     assert len(block_shape) == rank, "block shape must have same length as tensor shape"
 
@@ -53,15 +41,8 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0, write_only=False):
     tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
     tma_shape = [max_int, max_int] + tensor_shape
     box_shape = [1, 1] + block_shape
-    ptr = T.data_ptr()
 
-    if write_only:
-        tma_stride = tma_stride[1:]
-        tma_shape = tma_shape[1:]
-        box_shape = box_shape[1:]
-        ptr = (ptr - billion * ragged_stride * T.element_size()) % (2**64)
-
-    return TensorDescriptor(TensorDescriptorPtr(ptr, T.dtype), tma_shape, tma_stride, box_shape)
+    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
 
 
 @triton.jit
@@ -106,18 +87,6 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.con
     TMA.store().
     """
 
-    if len(TMA.shape) == len(coords) + 1:
-        write_only: tl.constexpr = True
-    elif len(TMA.shape) == len(coords) + 2:
-        write_only: tl.constexpr = False
-    else:
-        tl.static_assert(False, "TMA must be a ragged descriptor")
-
     c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
-
-    if write_only:
-        data = tl.reshape(data, [1] + data.shape)
-        TMA.store([c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
-    else:
-        data = tl.reshape(data, [1, 1] + data.shape)
-        TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
+    data = tl.reshape(data, [1, 1] + data.shape)
+    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
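
With the write-only variant gone, ragged descriptors are created the same way for reads and writes and consumed through store_ragged (and its load counterpart) inside a kernel. A minimal host-side sketch, modeled on the test above and assuming a load_ragged helper with the mirrored argument order (the kernel body is illustrative):

import torch
import triton
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged


@triton.jit
def copy_ragged_kernel(X, Y, x_off, y_off, x_size, y_size):
    # read a [32, 128] tile from one ragged batch and write it into another
    data = load_ragged(X, x_off, x_size, [0, 0])
    store_ragged(Y, y_off, y_size, [0, 0], data)


src = torch.randn((1024, 80), dtype=torch.float32, device="cuda")
dst = torch.zeros_like(src)
X = create_ragged_descriptor(src, [32, 128])  # one constructor for both source and destination
Y = create_ragged_descriptor(dst, [32, 128])
copy_ragged_kernel[(1,)](X, Y, 42, 51, 30, 30)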

python/triton_kernels/tests/test_matmul.py

Lines changed: 9 additions & 2 deletions
@@ -161,6 +161,13 @@ class Case:
     ", ".join(f.name for f in fields(Case)),
     [
         tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+            # Zero-sized args:
+            Case(0, 5, 7, "ragged", "float16", "float16"),
+            Case(5, 0, 7, "ragged", "float16", "float16"),
+            Case(5, 7, 0, "ragged", "float16", "float16"),
+            Case(0, 5, 7, "batched", "float16", "float16"),
+            Case(5, 0, 7, "batched", "float16", "float16"),
+            Case(5, 7, 0, "batched", "float16", "float16"),
             # Non-mx types:
             Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
             Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
@@ -301,7 +308,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     # launch metadata for batched / mx types may not work yet.
-    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) and fused_scatter and m*n*k != 0
 
     torch.manual_seed(0)
 
@@ -349,7 +356,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                                       has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)
 
-    if w_tri.shape[0] == 1:
+    if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
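
The new zero-sized cases pin down the degenerate shapes, which should match PyTorch's semantics for empty matrix products: a zero contraction dimension yields a zero-filled output, and a zero output dimension yields an empty tensor of the right shape. For reference (plain torch, illustrative only):

import torch

a = torch.matmul(torch.randn(5, 0), torch.randn(0, 7))  # k == 0
b = torch.matmul(torch.randn(0, 5), torch.randn(5, 7))  # m == 0
print(a.shape, a.abs().sum())  # torch.Size([5, 7]) tensor(0.)
print(b.shape)                 # torch.Size([0, 7])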

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 29 additions & 10 deletions
@@ -6,6 +6,7 @@
 import torch
 import triton
 from enum import Enum, auto
+import math
 # utilities
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
@@ -427,6 +428,7 @@ def matmul_ogs(x, w, bias,
     if not isinstance(x, Tensor):
         x = Tensor(x, dtype=x.dtype)
     # determine shapes
+    is_ragged = routing_data.expt_hist is not None
     M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
     batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
     K, N = w.shape[-2:]
@@ -457,6 +459,11 @@ def matmul_ogs(x, w, bias,
         opt_flags, preprocessing_features, postprocessing_features
     )
     memory = apply_allocation(allocation, y)
+    if batch_size * M * N == 0:
+        ret = memory["output"].squeeze(0)
+        if not is_input_batched:
+            ret = ret.squeeze(0)
+        return ret
     # TMA descriptors require a global memory allocation
     if opt_flags.is_persistent:
         triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
@@ -505,19 +512,31 @@ def matmul_ogs(x, w, bias,
     grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
     # canonicalize storage
     has_gather = gather_indx is not None
-    x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data)
+    has_scatter = writeback_idxs is not None
+    has_gather_tma = has_gather and target_info.has_tma_gather()
+    has_scatter_tma = has_scatter and target_info.has_tma_gather()
+    y = wrap_torch_tensor(out0.view(math.prod(out0.shape[:-1]), out0.shape[-1]) if has_scatter else out0.view(math.prod(out0.shape[:-2]), *out0.shape[-2:]))
+    x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
     w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+    y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
     # create tma descriptor for x
-    x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent
-    x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k]
-    x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data
+    x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+    x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+    x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+    x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+    # create tma descriptor for y
+    y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not has_scatter)
+    block_n = opt_flags.block_n // opt_flags.epilogue_subtile // fused_activation.reduction_n
+    y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+    y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+    y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
     # create tma descriptor for w
     w_has_tma = opt_flags.is_persistent
-    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data
+    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
     # create tma descriptor for w_scale
     w_scale_tensor_or_tma = w_scale
     w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
-    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale
+    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
     # canonicalize strides
     x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
     x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
@@ -529,14 +548,13 @@ def matmul_ogs(x, w, bias,
     # launch kernel
     kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
-        flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
+        y_tensor_or_tma, y_storage.data, *out0.stride(),
         *((None, out_scale, None) if out_has_mx else out0_flex),
         *out_scale_strides[-3:],
         x_tensor_or_tma, x_storage.data, *x_strides,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
-        w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
+        w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
         flex.rhs_data.scale,
         w_scale_tensor_or_tma, *w_scale_strides,
         bias, bias_stride,
@@ -574,7 +592,8 @@ def matmul_ogs(x, w, bias,
         num_stages=opt_flags.num_stages,
         arch=opt_flags.arch,
         UPCAST_INDICES=should_upcast_indices(x, w, out0),
-        DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
+        X_TMA_MODE=x_tma_mode,
+        Y_TMA_MODE=y_tma_mode,
        SWAP_XW=preprocessing_features.swap_xw,
         IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name,
         NUM_SMS = grid if opt_flags.is_persistent else 0,
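
Both the x and the y descriptors above follow the same selection rule: build a TMA descriptor only for persistent kernels, use it on the gather/scatter path only when the hardware has TMA gather, and pick the ragged flavor only when the layout is ragged and no gather/scatter TMA is involved. A compact restatement of that rule (the function and argument names here are illustrative, not part of the API):

from typing import Optional


def tma_mode(is_persistent: bool, has_index_op: bool,
             hw_has_tma_gather: bool, is_ragged: bool) -> Optional[str]:
    # mirrors x_tma_mode / y_tma_mode above: None means "pass the raw tensor",
    # otherwise build a ragged or dense TMA descriptor
    has_index_tma = has_index_op and hw_has_tma_gather  # gather for x, scatter for y
    has_tma = is_persistent and (has_index_tma or not has_index_op)
    if not has_tma:
        return None
    return "ragged" if (is_ragged and not has_index_tma) else "dense"


assert tma_mode(True, False, True, True) == "ragged"   # persistent, ragged, no gather/scatter
assert tma_mode(True, True, False, True) is None       # gather present but no TMA-gather hardware
assert tma_mode(True, True, True, True) == "dense"     # TMA gather handles the indexing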

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 
 import triton
 import triton.language as tl
-from triton.tools.tensor_descriptor import TensorDescriptor
 
 # -----------------------------------------------------------------------------
 # Utilities
@@ -94,7 +93,7 @@ def matmul_launch_metadata(grid, kernel, args):
 
     ret = dict()
     M, N, K = args["M"], args["N"], args["K"]
-    Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
+    Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
     tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
     hist = args["ExptHist"]
     if hist is not None:
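
matmul_launch_metadata is the hook registered via launch_metadata= on the @triton.jit decorator below: Triton invokes it with the launch grid, the compiled kernel, and a dict of the kernel's arguments keyed by parameter name, and the returned dict annotates the launch for profiling. The change here just reads the raw YPtr/XPtr/WPtr tensors instead of unwrapping TensorDescriptor objects. A toy sketch of the hook pattern, assuming the conventional "name"/"bytes" keys (not this repo's kernel):

import torch
import triton
import triton.language as tl


def add_launch_metadata(grid, kernel, args):
    # args is keyed by kernel parameter name, just like args["YPtr"] above
    n = args["n_elements"]
    return {"name": f"add [n={n}]", "bytes": 3 * n * args["x"].element_size()}


@triton.jit(launch_metadata=add_launch_metadata)
def add_kernel(x, y, out, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out + offs, tl.load(x + offs, mask=mask) + tl.load(y + offs, mask=mask), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
add_kernel[(triton.cdiv(1024, 256),)](x, y, out, 1024, BLOCK=256)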

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 4 additions & 4 deletions
@@ -30,13 +30,13 @@ def _zero_masked_rows(
 @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
             repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
 def _matmul_ogs(
-        Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
+        Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
         YExpectedScale, YActualScale, YChecksumScale,
         stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
         X, XPtr, stride_x_z, stride_x_m, stride_x_k,
         XScale,
         XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
-        W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
+        W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
         WScale,
         WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
         B, stride_b_e, # Bias
@@ -72,13 +72,13 @@ def _matmul_ogs(
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
         W_CACHE_MODIFIER: tl.constexpr,
         NUM_SMS: tl.constexpr,
+        X_TMA_MODE: tl.constexpr,
+        Y_TMA_MODE: tl.constexpr,
         TOKENS_PER_EXPT_FOR_ANNOTATION=None,
         UPCAST_INDICES: tl.constexpr = False,
-        DISABLE_Y_TMA: tl.constexpr = True,
         SWAP_XW: tl.constexpr = False,
         IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
 
-    Y = Out # Y is passed for the purposes of annotation; replace it with Out
     is_w_microscaled: tl.constexpr = WMxScale is not None
     MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
     if is_w_microscaled:
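
X_TMA_MODE and Y_TMA_MODE arrive as tl.constexpr values, so the kernel can branch on them at compile time and each specialization keeps only the selected path (the same idiom the Triton matmul tutorial uses for its ACTIVATION string). A generic sketch of that idiom with a toy kernel and flag (not the _matmul_ogs body):

import torch
import triton
import triton.language as tl


@triton.jit
def transform_kernel(x, out, n, MODE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    v = tl.load(x + offs, mask=mask)
    # the constexpr comparison is resolved during specialization, so the compiled
    # kernel contains only one of these branches
    if MODE == "double":
        v = v * 2
    elif MODE == "negate":
        v = -v
    tl.store(out + offs, v, mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
transform_kernel[(triton.cdiv(1024, 256),)](x, out, 1024, MODE="double", BLOCK=256)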
