@@ -1014,6 +1014,150 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
    _launcher(_helion_jagged_mean_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

+ --- assertExpectedJournal(TestExamples.test_jagged_softmax)
+ from __future__ import annotations
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime import triton_helpers
+ from torch._inductor.runtime.triton_helpers import math as tl_math
+ from helion.runtime import default_launcher as _default_launcher
+
+ @triton.jit
+ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+     pid_0 = tl.program_id(0)
+     offset_0 = pid_0 * _BLOCK_SIZE_0
+     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+     mask_0 = indices_0 < num_rows
+     starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
+     v_0 = tl.full([], 1, tl.int32)
+     v_1 = indices_0 + v_0
+     ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
+     v_2 = ends - starts
+     _mask_to = tl.where(mask_0, v_2, -9223372036854775808)
+     max_seqlen = tl.max(_mask_to, 0)
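+     # Each program instance handles one block of rows: consecutive offsets give
+     # each row's start and length, and max_seqlen is the longest sequence in the block.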
+     for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
+         indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+         mask_1 = indices_1 < M
+         max_seqlen_copy = max_seqlen
+         starts_copy = starts
+         v_2_copy = v_2
+         max_seqlen_copy_0 = max_seqlen_copy
+         starts_copy_0 = starts_copy
+         v_2_copy_0 = v_2_copy
+         block_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+         block_new_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+         block_L = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
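+         # First pass over the sequence: an online-softmax update that carries a
+         # running maximum (block_max) and a rescaled running sum of exponentials (block_L).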
+         for offset_2 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_2):
+             indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+             mask_2 = indices_2 < max_seqlen_copy_0
+             starts_copy_0_copy = starts_copy_0
+             v_2_copy_0_copy = v_2_copy_0
+             block_max_copy = block_max
+             block_L_copy = block_L
+             starts_copy_0_copy_0 = starts_copy_0_copy
+             v_2_copy_0_copy_0 = v_2_copy_0_copy
+             block_max_copy_0 = block_max_copy
+             block_L_copy_0 = block_L_copy
+             subscript = starts_copy_0_copy_0[:, None]
+             subscript_1 = indices_2[None, :]
+             v_3 = subscript_1.to(tl.int64)
+             v_4 = subscript + v_3
+             subscript_2 = v_4[:, :, None]
+             v_5 = subscript_2 * M
+             subscript_3 = indices_1[None, None, :]
+             v_6 = subscript_3.to(tl.int64)
+             v_7 = v_5 + v_6
+             subscript_4 = indices_2[None, :]
+             subscript_5 = v_2_copy_0_copy_0[:, None]
+             v_8 = subscript_4.to(tl.int64)
+             v_9 = v_8 < subscript_5
+             subscript_6 = v_9[:, :, None]
+             v_10 = M.to(tl.int32)
+             v_11 = indices_1 < v_10
+             subscript_7 = v_11[None, None, :]
+             v_12 = subscript_6 & subscript_7
+             x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
+             v_13 = float('-inf')
+             v_14 = v_13[None, None, None]
+             v_15 = tl.where(v_12, x_slice, v_14)
+             _mask_to_1 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_15, float('-inf'))
+             slice_max = tl.max(_mask_to_1, 1)
+             block_new_max = triton_helpers.maximum(block_max_copy_0, slice_max)
+             v_17 = block_max_copy_0 - block_new_max
+             v_18 = tl_math.exp(v_17)
+             v_19 = block_L_copy_0 * v_18
+             subscript_8 = block_new_max[:, None, :]
+             v_20 = x_slice - subscript_8
+             v_21 = float('-inf')
+             v_22 = v_21[None, None, None]
+             v_23 = tl.where(v_12, v_20, v_22)
+             v_24 = tl_math.exp(v_23)
+             _mask_to_2 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_24, 0)
+             sum_1 = tl.sum(_mask_to_2, 1)
+             block_L = v_19 + sum_1
+             block_max = block_new_max
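+         # Second pass: reload each slice, compute exp(x - block_max) / block_L,
+         # and store the normalized scores back in the jagged layout.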
+         for offset_3 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_3):
+             indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+             mask_3 = indices_3 < max_seqlen_copy_0
+             starts_copy_0_copy_1 = starts_copy_0
+             v_2_copy_0_copy_1 = v_2_copy_0
+             block_max_copy_1 = block_max
+             block_L_copy_1 = block_L
+             starts_copy_0_copy_1_0 = starts_copy_0_copy_1
+             v_2_copy_0_copy_1_0 = v_2_copy_0_copy_1
+             block_max_copy_1_0 = block_max_copy_1
+             block_L_copy_1_0 = block_L_copy_1
+             subscript_9 = starts_copy_0_copy_1_0[:, None]
+             subscript_10 = indices_3[None, :]
+             v_26 = subscript_10.to(tl.int64)
+             v_27 = subscript_9 + v_26
+             subscript_11 = v_27[:, :, None]
+             v_28 = subscript_11 * M
+             subscript_12 = indices_1[None, None, :]
+             v_29 = subscript_12.to(tl.int64)
+             v_30 = v_28 + v_29
+             subscript_13 = indices_3[None, :]
+             subscript_14 = v_2_copy_0_copy_1_0[:, None]
+             v_31 = subscript_13.to(tl.int64)
+             v_32 = v_31 < subscript_14
+             subscript_15 = v_32[:, :, None]
+             v_33 = M.to(tl.int32)
+             v_34 = indices_1 < v_33
+             subscript_16 = v_34[None, None, :]
+             v_35 = subscript_15 & subscript_16
+             x_slice_1 = tl.load(x_flat + v_30 * x_flat_stride_0, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35, other=0)
+             subscript_17 = block_max_copy_1_0[:, None, :]
+             v_36 = x_slice_1 - subscript_17
+             v_37 = tl_math.exp(v_36)
+             subscript_18 = block_L_copy_1_0[:, None, :]
+             v_38 = v_37 / subscript_18
+             tl.store(out + v_30 * out_stride_0, v_38, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35)
+
+ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
+     """
+     Compute the per-batch softmax in a jagged tensor.
+
+     Args:
+         x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
+         x_offsets: (num_rows + 1) tensor. Row i is the slice
+             x_data[x_offsets[i] : x_offsets[i+1], :]
+
+     Returns:
+         2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
+     """
+     N = int(x_offsets[-1].item())
+     num_rows, M = (x_offsets.size(0) - 1, x_data.size(1))
+     out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)
+     x_flat = x_data.view(-1)
+     _BLOCK_SIZE_0 = 16
+     _BLOCK_SIZE_1 = 8
+     _BLOCK_SIZE_2 = 16
+     _BLOCK_SIZE_3 = 16
+     _launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+     return out.reshape(N, M)
+
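For reference, a minimal sketch of how the generated entry point above could be exercised; the jagged layout and sizes here are illustrative assumptions, not part of the expected journal:

import torch

# Hypothetical inputs: three batches with sequence lengths 2, 0 and 3,
# each position carrying M = 4 features.
x_offsets = torch.tensor([0, 2, 2, 5], device="cuda")
x_data = torch.randn(5, 4, device="cuda")   # (total_elements, max_M)

out = jagged_softmax_kernel(x_data, x_offsets)   # shape (5, 4)
# Softmax runs over each batch's own sequence, so the scores of a
# non-empty batch sum to 1 in every feature column:
print(out[0:2].sum(dim=0))   # ~ tensor([1., 1., 1., 1.])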
--- assertExpectedJournal(TestExamples.test_layernorm)
from __future__ import annotations