
Commit 043d0f4

[example] add jagged_softmax example (#480)
1 parent db41224 commit 043d0f4

4 files changed: +364 -0 lines changed
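For context: a jagged (nested) tensor packs variable-length rows into one flat (total_elements, M) buffer plus an offsets vector, and "jagged softmax" normalizes each row span independently along its length. A minimal self-contained sketch of that semantics in plain PyTorch (illustrative, not part of the commit):

import torch

# Three rows of lengths 2, 3, 1 with M=2 features, packed into one buffer.
x_data = torch.randn(6, 2)
x_offsets = torch.tensor([0, 2, 5, 6])

# Softmax along dim=0 within each row span.
out = torch.cat(
    [torch.softmax(x_data[i:j], dim=0) for i, j in zip(x_offsets[:-1], x_offsets[1:])]
)
assert out.shape == x_data.shape
assert torch.allclose(out[0:2].sum(dim=0), torch.ones(2))  # each span's columns sum to 1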

benchmarks/run.py

Lines changed: 5 additions & 0 deletions
@@ -101,6 +101,11 @@
         "examples.layer_norm",
         "layer_norm_fwd",
     ),
+    "jagged_softmax": (
+        "tritonbench.operators.jagged_softmax.operator",
+        "examples.jagged_softmax",
+        "jagged_softmax_tritonbench",
+    ),
     # Multiple kernel variants:
     "gemm": (
         "tritonbench.operators.gemm.operator",

examples/jagged_softmax.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
"""
Jagged Softmax Example
======================

This example demonstrates how to compute the softmax across each batch in a jagged tensor using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import itertools

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Reference Implementation
# ------------------------
def reference_jagged_softmax_pytorch(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    """
    PyTorch reference implementation for jagged softmax.

    Args:
        x_data: 2-D tensor holding all elements
        x_offsets: Offsets tensor for row indexing

    Returns:
        Tensor containing the per-batch softmax scores (same shape as x_data)
    """
    vals = []
    for i, j in itertools.pairwise(x_offsets):
        y = x_data[i:j]
        vals.append(torch.softmax(y, dim=0))
    return torch.cat(vals, dim=0)


# %%
# Jagged Softmax Kernel
# ---------------------
@helion.kernel()
def jagged_softmax_kernel(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the per-batch softmax in a jagged tensor.

    Args:
        x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
        x_offsets: (num_rows + 1) tensor. Row i is the slice
            x_data[x_offsets[i] : x_offsets[i+1], :]

    Returns:
        2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
    """
    N = int(x_offsets[-1].item())
    num_rows, M = x_offsets.size(0) - 1, x_data.size(1)
    out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)

    # flatten so each (row, feature) element is addressable by a single index
    x_flat = x_data.view(-1)

    for tile_b in hl.tile(num_rows):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        seqlens = ends - starts
        max_seqlen = seqlens.amax()

        for tile_m in hl.tile(M):
            block_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)
            block_new_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)
            block_L = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)

            # First pass: online softmax, accumulating the running max and normalizer
            for tile_k in hl.tile(max_seqlen):
                base_indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = (
                    base_indices[:, :, None] * M + tile_m.index[None, None, :]
                )
                row_mask = tile_k.index[None, :] < seqlens[:, None]
                combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :]
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )
                slice_max = torch.where(combined_mask, x_slice, float("-inf")).amax(
                    dim=1
                )
                block_new_max = torch.maximum(block_max, slice_max)
                block_L *= torch.exp(block_max - block_new_max)
                block_L += torch.exp(
                    torch.where(
                        combined_mask,
                        x_slice - block_new_max[:, None, :],
                        float("-inf"),
                    )
                ).sum(dim=1)
                block_max = block_new_max

            # Second pass: normalize each element and write it back out
            for tile_k in hl.tile(max_seqlen):
                base_indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = (
                    base_indices[:, :, None] * M + tile_m.index[None, None, :]
                )
                row_mask = tile_k.index[None, :] < seqlens[:, None]
                combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :]
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )
                block_out = (
                    torch.exp(x_slice - block_max[:, None, :]) / block_L[:, None, :]
                )
                hl.store(
                    out,
                    [flat_indices],
                    block_out,
                    extra_mask=combined_mask,
                )

    return out.reshape(N, M)


# %%
# Benchmark Wrapper
# -----------------
def jagged_softmax_tritonbench(
    x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> torch.Tensor:
    """
    Wrapper for tritonbench that matches the expected interface.

    Args:
        x: Nested tensor in jagged format with shape (B, *, M)
        B: Batch size (unused)
        M: Number of features (unused)
        seqlen: Maximum sequence length (unused)
        sparsity: Sparsity factor (unused)

    Returns:
        Tensor of shape (N, M), where N = total number of rows in the jagged tensor
    """
    return jagged_softmax_kernel(x._values, x._offsets)  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point for jagged softmax kernel verification.
    """
    num_rows, max_cols = 512, 64
    device = "cuda"

    lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device)
    x_offsets = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)]
    )
    nnz = int(x_offsets[-1])
    M = 128  # number of features
    x_data = torch.randn(nnz, M, dtype=torch.float32, device=device)

    out_eager = reference_jagged_softmax_pytorch(x_data, x_offsets)
    out_hl = jagged_softmax_kernel(x_data, x_offsets)
    assert torch.allclose(out_eager, out_hl)

    run_example(
        lambda x, o: jagged_softmax_kernel(x, o),
        lambda x, o: reference_jagged_softmax_pytorch(x, o),
        (x_data, x_offsets),
    )


if __name__ == "__main__":
    main()
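The first inner loop over tile_k is the standard online (streaming) softmax: it maintains a running maximum (block_max) and a normalizer (block_L) that is rescaled whenever the maximum grows, so the softmax denominator is built in one pass without materializing the whole row. A minimal 1-D sketch of that recurrence, checked against torch.softmax (illustrative; the function name and block size are assumptions, not taken from the kernel):

import torch

def online_normalizer(x: torch.Tensor, block: int = 4):
    # Running max m and rescaled running sum l, mirroring the kernel's
    # block_max / block_L updates over chunks of the sequence.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    for k in range(0, x.numel(), block):
        chunk = x[k : k + block]
        m_new = torch.maximum(m, chunk.max())
        l = l * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum()
        m = m_new
    return m, l

x = torch.randn(10)
m, l = online_normalizer(x)
assert torch.allclose(torch.exp(x - m) / l, torch.softmax(x, dim=0))

Note the kernel initializes block_max to 0.0 rather than -inf; the result is still exact because exp(x - m) / sum(exp(x_i - m)) is invariant to the shift m, so starting at 0 simply means m ends up as max(0, row max), which is enough for numerical stability here.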

test/test_examples.expected

Lines changed: 144 additions & 0 deletions
@@ -1014,6 +1014,150 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
     _launcher(_helion_jagged_mean_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_jagged_softmax)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < num_rows
+    starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
+    v_0 = tl.full([], 1, tl.int32)
+    v_1 = indices_0 + v_0
+    ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
+    v_2 = ends - starts
+    _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
+    max_seqlen = tl.max(_mask_to, 0)
+    for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < M
+        max_seqlen_copy = max_seqlen
+        starts_copy = starts
+        v_2_copy = v_2
+        max_seqlen_copy_0 = max_seqlen_copy
+        starts_copy_0 = starts_copy
+        v_2_copy_0 = v_2_copy
+        block_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+        block_new_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+        block_L = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+        for offset_2 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < max_seqlen_copy_0
+            starts_copy_0_copy = starts_copy_0
+            v_2_copy_0_copy = v_2_copy_0
+            block_max_copy = block_max
+            block_L_copy = block_L
+            starts_copy_0_copy_0 = starts_copy_0_copy
+            v_2_copy_0_copy_0 = v_2_copy_0_copy
+            block_max_copy_0 = block_max_copy
+            block_L_copy_0 = block_L_copy
+            subscript = starts_copy_0_copy_0[:, None]
+            subscript_1 = indices_2[None, :]
+            v_3 = subscript_1.to(tl.int64)
+            v_4 = subscript + v_3
+            subscript_2 = v_4[:, :, None]
+            v_5 = subscript_2 * M
+            subscript_3 = indices_1[None, None, :]
+            v_6 = subscript_3.to(tl.int64)
+            v_7 = v_5 + v_6
+            subscript_4 = indices_2[None, :]
+            subscript_5 = v_2_copy_0_copy_0[:, None]
+            v_8 = subscript_4.to(tl.int64)
+            v_9 = v_8 < subscript_5
+            subscript_6 = v_9[:, :, None]
+            v_10 = M.to(tl.int32)
+            v_11 = indices_1 < v_10
+            subscript_7 = v_11[None, None, :]
+            v_12 = subscript_6 & subscript_7
+            x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
+            v_13 = float('-inf')
+            v_14 = v_13[None, None, None]
+            v_15 = tl.where(v_12, x_slice, v_14)
+            _mask_to_1 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_15, float('-inf'))
+            slice_max = tl.max(_mask_to_1, 1)
+            block_new_max = triton_helpers.maximum(block_max_copy_0, slice_max)
+            v_17 = block_max_copy_0 - block_new_max
+            v_18 = tl_math.exp(v_17)
+            v_19 = block_L_copy_0 * v_18
+            subscript_8 = block_new_max[:, None, :]
+            v_20 = x_slice - subscript_8
+            v_21 = float('-inf')
+            v_22 = v_21[None, None, None]
+            v_23 = tl.where(v_12, v_20, v_22)
+            v_24 = tl_math.exp(v_23)
+            _mask_to_2 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_24, 0)
+            sum_1 = tl.sum(_mask_to_2, 1)
+            block_L = v_19 + sum_1
+            block_max = block_new_max
+        for offset_3 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_3):
+            indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+            mask_3 = indices_3 < max_seqlen_copy_0
+            starts_copy_0_copy_1 = starts_copy_0
+            v_2_copy_0_copy_1 = v_2_copy_0
+            block_max_copy_1 = block_max
+            block_L_copy_1 = block_L
+            starts_copy_0_copy_1_0 = starts_copy_0_copy_1
+            v_2_copy_0_copy_1_0 = v_2_copy_0_copy_1
+            block_max_copy_1_0 = block_max_copy_1
+            block_L_copy_1_0 = block_L_copy_1
+            subscript_9 = starts_copy_0_copy_1_0[:, None]
+            subscript_10 = indices_3[None, :]
+            v_26 = subscript_10.to(tl.int64)
+            v_27 = subscript_9 + v_26
+            subscript_11 = v_27[:, :, None]
+            v_28 = subscript_11 * M
+            subscript_12 = indices_1[None, None, :]
+            v_29 = subscript_12.to(tl.int64)
+            v_30 = v_28 + v_29
+            subscript_13 = indices_3[None, :]
+            subscript_14 = v_2_copy_0_copy_1_0[:, None]
+            v_31 = subscript_13.to(tl.int64)
+            v_32 = v_31 < subscript_14
+            subscript_15 = v_32[:, :, None]
+            v_33 = M.to(tl.int32)
+            v_34 = indices_1 < v_33
+            subscript_16 = v_34[None, None, :]
+            v_35 = subscript_15 & subscript_16
+            x_slice_1 = tl.load(x_flat + v_30 * x_flat_stride_0, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35, other=0)
+            subscript_17 = block_max_copy_1_0[:, None, :]
+            v_36 = x_slice_1 - subscript_17
+            v_37 = tl_math.exp(v_36)
+            subscript_18 = block_L_copy_1_0[:, None, :]
+            v_38 = v_37 / subscript_18
+            tl.store(out + v_30 * out_stride_0, v_38, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35)
+
+def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Compute the per-batch softmax in a jagged tensor.
+
+    Args:
+        x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
+        x_offsets: (num_rows + 1) tensor. Row i is the slice
+            x_data[x_offsets[i] : x_offsets[i+1], :]
+
+    Returns:
+        2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
+    """
+    N = int(x_offsets[-1].item())
+    num_rows, M = (x_offsets.size(0) - 1, x_data.size(1))
+    out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)
+    x_flat = x_data.view(-1)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_3 = 16
+    _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out.reshape(N, M)
+
 --- assertExpectedJournal(TestExamples.test_layernorm)
 from __future__ import annotations
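Both the Helion source and the generated Triton address element (row r, position k, feature m) of the flattened buffer as (starts[r] + k) * M + m, with the combined masks (v_12, v_35) zeroing positions past each row's length. A quick self-contained check of that flat-index arithmetic (illustrative only):

import torch

total_elements, M = 7, 3
x_data = torch.arange(total_elements * M).reshape(total_elements, M)
x_flat = x_data.reshape(-1)

start, k, m = 2, 1, 2  # row start offset, position within the row, feature
assert x_flat[(start + k) * M + m] == x_data[start + k, m]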

test/test_examples.py

Lines changed: 29 additions & 0 deletions
@@ -629,6 +629,35 @@ def test_layernorm(self):
             )
         )
 
+    @skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store")
+    def test_jagged_softmax(self):
+        num_rows, max_cols = 128, 64
+        M = 8  # number of features
+        lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
+        x_offsets = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.long, device=DEVICE),
+                torch.cumsum(lengths, dim=0),
+            ]
+        )
+        nnz = int(x_offsets[-1])
+        x_data = torch.randn(nnz, M, dtype=torch.float32, device=DEVICE)
+        args = (x_data, x_offsets)
+
+        # Import and use the reference implementation
+        mod = import_path(EXAMPLES_DIR / "jagged_softmax.py")
+        expected = mod.reference_jagged_softmax_pytorch(x_data, x_offsets)
+
+        self.assertExpectedJournal(
+            check_example(
+                "jagged_softmax",
+                args,
+                expected,
+                fn_name="jagged_softmax_kernel",
+                block_sizes=[16, 8, 16, 16],
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
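Pinning block_sizes=[16, 8, 16, 16] fixes the kernel's four tile sizes (_BLOCK_SIZE_0 through _BLOCK_SIZE_3), which keeps the generated code, and hence the expected journal above, deterministic. With num_rows=128 and _BLOCK_SIZE_0=16, the launch grid recorded in the journal is triton.cdiv(128, 16) = 8 programs; a one-line ceiling-division check (illustrative):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the behavior of triton.cdiv.
    return -(-a // b)

assert cdiv(128, 16) == 8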
