diff --git a/benchmarks/run.py b/benchmarks/run.py
index b6273b51..2c25d275 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -101,6 +101,11 @@
         "examples.layer_norm",
         "layer_norm_fwd",
     ),
+    "jagged_softmax": (
+        "tritonbench.operators.jagged_softmax.operator",
+        "examples.jagged_softmax",
+        "jagged_softmax_tritonbench",
+    ),
     # Multiple kernel variants:
     "gemm": (
         "tritonbench.operators.gemm.operator",
diff --git a/examples/jagged_softmax.py b/examples/jagged_softmax.py
new file mode 100644
index 00000000..00de17c2
--- /dev/null
+++ b/examples/jagged_softmax.py
@@ -0,0 +1,214 @@
+"""
+Jagged Softmax Example
+======================
+
+This example demonstrates how to compute the softmax across each batch in a jagged tensor using Helion.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+import itertools
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+
+# %%
+# Reference Implementation
+# ------------------------
+def reference_jagged_softmax_pytorch(
+    x_data: torch.Tensor,
+    x_offsets: torch.Tensor,
+) -> torch.Tensor:
+    """
+    PyTorch reference implementation for jagged softmax.
+
+    Args:
+        x_data: 2-D tensor holding all elements
+        x_offsets: Offsets tensor for row indexing
+
+    Returns:
+        Tensor containing the per-batch softmax scores (same shape as x_data)
+    """
+    vals = []
+    for i, j in itertools.pairwise(x_offsets):
+        y = x_data[i:j]
+        vals.append(torch.softmax(y, dim=0))
+    return torch.cat(vals, dim=0)
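+
+
+# %%
+# Streaming Softmax Sketch
+# ------------------------
+# The Helion kernel below never materializes a whole batch at once: it makes
+# two passes, first maintaining a running maximum and a rescaled running sum
+# of exponentials, then normalizing. A plain-PyTorch sketch of that recurrence
+# (illustrative only; this helper and its name are not used by the kernel):
+def streaming_softmax_sketch(x: torch.Tensor, block: int = 16) -> torch.Tensor:
+    """Softmax over dim 0, computed one block at a time (for exposition)."""
+    m = torch.full(x.shape[1:], float("-inf"), dtype=x.dtype, device=x.device)
+    L = torch.zeros(x.shape[1:], dtype=x.dtype, device=x.device)
+    for start in range(0, x.shape[0], block):
+        chunk = x[start : start + block]
+        m_new = torch.maximum(m, chunk.amax(dim=0))
+        # Rescale the old sum to the new maximum, then fold in this block.
+        L = L * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum(dim=0)
+        m = m_new
+    # Second pass: normalize every element with the final max and sum.
+    return torch.exp(x - m) / L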
+ """ + N = int(x_offsets[-1].item()) + num_rows, M = x_offsets.size(0) - 1, x_data.size(1) + out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device) + + # flatten + x_flat = x_data.view(-1) + + for tile_b in hl.tile(num_rows): + starts = x_offsets[tile_b] + ends = x_offsets[tile_b.index + 1] + seqlens = ends - starts + max_seqlen = seqlens.amax() + + for tile_m in hl.tile(M): + block_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype) + block_new_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype) + block_L = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype) + + for tile_k in hl.tile(max_seqlen): + base_indices = starts[:, None] + tile_k.index[None, :] + flat_indices = ( + base_indices[:, :, None] * M + tile_m.index[None, None, :] + ) + row_mask = tile_k.index[None, :] < seqlens[:, None] + combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :] + x_slice = hl.load( + x_flat, + [flat_indices], + extra_mask=combined_mask, + ) + slice_max = torch.where(combined_mask, x_slice, float("-inf")).amax( + dim=1 + ) + block_new_max = torch.maximum(block_max, slice_max) + block_L *= torch.exp(block_max - block_new_max) + block_L += torch.exp( + torch.where( + combined_mask, + x_slice - block_new_max[:, None, :], + float("-inf"), + ) + ).sum(dim=1) + block_max = block_new_max + + for tile_k in hl.tile(max_seqlen): + base_indices = starts[:, None] + tile_k.index[None, :] + flat_indices = ( + base_indices[:, :, None] * M + tile_m.index[None, None, :] + ) + row_mask = tile_k.index[None, :] < seqlens[:, None] + combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :] + x_slice = hl.load( + x_flat, + [flat_indices], + extra_mask=combined_mask, + ) + block_out = ( + torch.exp(x_slice - block_max[:, None, :]) / block_L[:, None, :] + ) + hl.store( + out, + [flat_indices], + block_out, + extra_mask=combined_mask, + ) + + return out.reshape(N, M) + + +# %% +# Benchmark Wrapper +# -------------- +def jagged_softmax_tritonbench( + x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float +) -> torch.Tensor: + """ + Wrapper for tritonbench that matches the expected interface. + + Args: + x: Nested tensor in jagged format with shape (B, *, M) + B: Batch size (unused) + M: Number of features (unused) + seqlen: Maximum sequence length (unused) + sparsity: Sparsity factor (unused) + + Returns: + Tensor of shape (N, M), where N = total number of rows in the jagged tensor + """ + return jagged_softmax_kernel(x._values, x._offsets) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] + + +# %% +# Main Function +# ----------- +def main() -> None: + """ + Main entry point for jagged softmax kernel verification. 
+ """ + num_rows, max_cols = 512, 64 + device = "cuda" + + lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) + x_offsets = torch.cat( + [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)] + ) + nnz = int(x_offsets[-1]) + M = 128 # number of features + x_data = torch.randn(nnz, M, dtype=torch.float32, device=device) + + out_eager = reference_jagged_softmax_pytorch(x_data, x_offsets) + out_hl = jagged_softmax_kernel(x_data, x_offsets) + assert torch.allclose(out_eager, out_hl) + + run_example( + lambda x, o: jagged_softmax_kernel(x, o), + lambda x, o: reference_jagged_softmax_pytorch(x, o), + (x_data, x_offsets), + ) + + +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index e096413d..9c6904bc 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1014,6 +1014,150 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_ _launcher(_helion_jagged_mean_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestExamples.test_jagged_softmax) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from torch._inductor.runtime.triton_helpers import math as tl_math +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < num_rows + starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0) + v_0 = tl.full([], 1, tl.int32) + v_1 = indices_0 + v_0 + ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0) + v_2 = ends - starts + _mask_to = tl.where(mask_0, v_2, -9223372036854775808) + max_seqlen = tl.max(_mask_to, 0) + for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < M + max_seqlen_copy = max_seqlen + starts_copy = starts + v_2_copy = v_2 + max_seqlen_copy_0 = max_seqlen_copy + starts_copy_0 = starts_copy + v_2_copy_0 = v_2_copy + block_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + block_new_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + block_L = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < max_seqlen_copy_0 + starts_copy_0_copy = starts_copy_0 + v_2_copy_0_copy = v_2_copy_0 + block_max_copy = block_max + block_L_copy = block_L + starts_copy_0_copy_0 = starts_copy_0_copy + v_2_copy_0_copy_0 = v_2_copy_0_copy + block_max_copy_0 = block_max_copy + block_L_copy_0 = block_L_copy + subscript = starts_copy_0_copy_0[:, None] + subscript_1 = indices_2[None, :] + v_3 = subscript_1.to(tl.int64) + v_4 = subscript + v_3 + subscript_2 = v_4[:, :, None] + v_5 = 
+    return jagged_softmax_kernel(x._values, x._offsets)  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point for jagged softmax kernel verification.
+    """
+    num_rows, max_cols = 512, 64
+    device = "cuda"
+
+    lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device)
+    x_offsets = torch.cat(
+        [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)]
+    )
+    nnz = int(x_offsets[-1])
+    M = 128  # number of features
+    x_data = torch.randn(nnz, M, dtype=torch.float32, device=device)
+
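+    # Check the kernel against the eager reference before benchmarking.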
+ """ + N = int(x_offsets[-1].item()) + num_rows, M = (x_offsets.size(0) - 1, x_data.size(1)) + out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device) + x_flat = x_data.view(-1) + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 8 + _BLOCK_SIZE_2 = 16 + _BLOCK_SIZE_3 = 16 + _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3) + return out.reshape(N, M) + --- assertExpectedJournal(TestExamples.test_layernorm) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 5d021047..52914127 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -629,6 +629,35 @@ def test_layernorm(self): ) ) + @skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store") + def test_jagged_softmax(self): + num_rows, max_cols = 128, 64 + M = 8 # number of features + lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE) + x_offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=DEVICE), + torch.cumsum(lengths, dim=0), + ] + ) + nnz = int(x_offsets[-1]) + x_data = torch.randn(nnz, M, dtype=torch.float32, device=DEVICE) + args = (x_data, x_offsets) + + # Import and use the reference implementation + mod = import_path(EXAMPLES_DIR / "jagged_softmax.py") + expected = mod.reference_jagged_softmax_pytorch(x_data, x_offsets) + + self.assertExpectedJournal( + check_example( + "jagged_softmax", + args, + expected, + fn_name="jagged_softmax_kernel", + block_sizes=[16, 8, 16, 16], + ) + ) + if __name__ == "__main__": unittest.main()