|
| 1 | +# adopt from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396 |
| 2 | +import torch |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | +from triton.language.standard import _log2, sum, zeros_like |
| 6 | + |
| 7 | + |
| 8 | +@triton.jit |
| 9 | +def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr): |
| 10 | + n_outer: tl.core.constexpr = x.numel >> n_dims |
| 11 | + shape: tl.core.constexpr = [n_outer * 2 ** i, 2, 2 ** (n_dims - i - 1)] |
| 12 | + y = tl.core.reshape(x, shape) |
| 13 | + # slice left/right with 'stride' 2**(n_dims - i - 1) |
| 14 | + mask = tl.core.arange(0, 2)[None, :, None] |
| 15 | + left = tl.core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) |
| 16 | + right = tl.core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) |
| 17 | + left = tl.core.reshape(left, x.shape) |
| 18 | + right = tl.core.reshape(right, x.shape) |
| 19 | + |
| 20 | + # idx |
| 21 | + y_idx = tl.core.reshape(ids, shape) |
| 22 | + left_idx = tl.core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape) |
| 23 | + right_idx = tl.core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape) |
| 24 | + left_idx = tl.core.reshape(left_idx, x.shape) |
| 25 | + right_idx = tl.core.reshape(right_idx, x.shape) |
| 26 | + |
| 27 | + # actual compare-and-swap |
| 28 | + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) |
| 29 | + ileft = left.to(idtype, bitcast=True) |
| 30 | + iright = right.to(idtype, bitcast=True) |
| 31 | + ix = x.to(idtype, bitcast=True) |
| 32 | + |
| 33 | + cond = (left > right) ^ flip |
| 34 | + |
| 35 | + ret = ix ^ tl.core.where(cond, ileft ^ iright, zeros_like(ix)) |
| 36 | + |
| 37 | + new_ids = ids ^ tl.core.where(cond, left_idx ^ right_idx, zeros_like(ids)) |
| 38 | + |
| 39 | + return ret.to(x.dtype, bitcast=True), new_ids |
| 40 | + |
| 41 | + |
| 42 | +@triton.jit |
| 43 | +def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr): |
| 44 | + """ |
| 45 | + order_type 0 == ascending |
| 46 | + order_type 1 == descending |
| 47 | + order_type 2 == alternating |
| 48 | + """ |
| 49 | + n_outer: tl.core.constexpr = x.numel >> n_dims |
| 50 | + tl.core.static_assert(stage <= n_dims) |
| 51 | + # flip denotes whether to re-arrange sub-sequences of elements in ascending or |
| 52 | + # descending order. |
| 53 | + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage |
| 54 | + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with |
| 55 | + # a stride of 2) at this stage |
| 56 | + if order == 2: |
| 57 | + shape: tl.core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2 ** stage] |
| 58 | + flip = tl.core.reshape(tl.core.broadcast_to(tl.core.arange(0, 2)[None, :, None], shape), x.shape) |
| 59 | + else: |
| 60 | + flip = order |
| 61 | + # perform `stage` rounds of `compare-and-swap` |
| 62 | + for i in tl.core.static_range(stage): |
| 63 | + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) |
| 64 | + return x, ids |
| 65 | + |
| 66 | + |
| 67 | +@triton.jit |
| 68 | +def argsort(x, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0): |
| 69 | + # handle default dimension or check that it is the most minor dim |
| 70 | + _dim: tl.core.constexpr = len(x.shape) - 1 if dim is None else dim |
| 71 | + tl.core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") |
| 72 | + # iteratively run bitonic merge-sort steps |
| 73 | + n_dims: tl.core.constexpr = _log2(x.shape[_dim]) |
| 74 | + |
| 75 | + for i in tl.core.static_range(1, n_dims + 1): |
| 76 | + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) |
| 77 | + return x, ids |
| 78 | + |
| 79 | + |
| 80 | +@triton.jit |
| 81 | +def grouped_topk_kernel( |
| 82 | + gating_output_ptr, |
| 83 | + gating_output_stride_m, |
| 84 | + gating_output_stride_n, |
| 85 | + correction_bias_ptr, |
| 86 | + scores_buffer_ptr, # [token_num, total_expert_num] |
| 87 | + scores_stride_m, |
| 88 | + scores_stride_n, |
| 89 | + scores_stride_token_m, |
| 90 | + scores_stride_group, |
| 91 | + scores_stride_group_v, |
| 92 | + out_topk_weights, |
| 93 | + out_topk_weights_stride_m, |
| 94 | + out_topk_weights_stride_n, |
| 95 | + out_topk_ids, |
| 96 | + out_topk_ids_stride_m, |
| 97 | + out_topk_ids_stride_n, |
| 98 | + group_num, |
| 99 | + group_expert_num, |
| 100 | + total_expert_num, # group_num * group_expert_num == total_expert_num |
| 101 | + topk_num, |
| 102 | + group_topk_num, |
| 103 | + IS_SIGMOID: tl.constexpr, |
| 104 | + HAS_CORRECTION_BIAS: tl.constexpr, |
| 105 | + EXPERT_BLOCK_SIZE: tl.constexpr, # tl.next_power_two_of(total_expert_num) |
| 106 | + EXPERT_GROUP_NUM: tl.constexpr, # tl.next_power_two_of(group_num) |
| 107 | + EXPERT_GROUP_SIZE: tl.constexpr, # tl.next_power_two_of(group_expert_num) |
| 108 | + RENORMALIZE: tl.constexpr, |
| 109 | +): |
| 110 | + token_index = tl.program_id(axis=0) |
| 111 | + offs_n = tl.arange(0, EXPERT_BLOCK_SIZE) |
| 112 | + hidden_states = tl.load( |
| 113 | + gating_output_ptr + token_index * gating_output_stride_m + offs_n, |
| 114 | + mask=offs_n < total_expert_num, |
| 115 | + other=-10000000.0, |
| 116 | + ) |
| 117 | + if IS_SIGMOID: |
| 118 | + scores = tl.sigmoid(hidden_states) |
| 119 | + else: |
| 120 | + scores = tl.softmax(hidden_states) |
| 121 | + |
| 122 | + if HAS_CORRECTION_BIAS: |
| 123 | + scores += tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=-10000000.0) |
| 124 | + |
| 125 | + offs_group = tl.arange(0, EXPERT_GROUP_NUM) |
| 126 | + offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE) |
| 127 | + tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num) |
| 128 | + group_scores = tl.load( |
| 129 | + scores_buffer_ptr |
| 130 | + + scores_stride_token_m * token_index |
| 131 | + + offs_group[:, None] * scores_stride_group |
| 132 | + + offs_group_v[None, :] * scores_stride_group_v, |
| 133 | + mask=(offs_group < group_num)[:, None] & (offs_group_v < group_expert_num)[None, :], |
| 134 | + other=-10000000.0, |
| 135 | + ) # [group, group_size] |
| 136 | + |
| 137 | + group_value = tl.max(group_scores, axis=1) # [group,] |
| 138 | + sorted_group_value = tl.sort(group_value, descending=True) |
| 139 | + group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0)) |
| 140 | + mask_group_scores = tl.where( |
| 141 | + ((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]), |
| 142 | + group_scores, |
| 143 | + -10000000.0, |
| 144 | + ) |
| 145 | + |
| 146 | + tl.store( |
| 147 | + scores_buffer_ptr |
| 148 | + + scores_stride_token_m * token_index |
| 149 | + + offs_group[:, None] * scores_stride_group |
| 150 | + + offs_group_v[None, :] * scores_stride_group_v, |
| 151 | + mask_group_scores, |
| 152 | + mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]), |
| 153 | + ) # [group, group_size] |
| 154 | + |
| 155 | + mask_scores = tl.load( |
| 156 | + scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0 |
| 157 | + ) |
| 158 | + sorted_scores, sorted_indexes = argsort(mask_scores, offs_n, descending=True) |
| 159 | + |
| 160 | + if RENORMALIZE: |
| 161 | + sum_scores = tl.sum(tl.where(offs_n < topk_num, sorted_scores, 0.0)) |
| 162 | + renormlize_scores = sorted_scores / sum_scores |
| 163 | + |
| 164 | + tl.store( |
| 165 | + out_topk_weights + token_index * out_topk_weights_stride_m + offs_n, |
| 166 | + renormlize_scores, |
| 167 | + mask=offs_n < topk_num, |
| 168 | + ) |
| 169 | + tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num) |
| 170 | + else: |
| 171 | + tl.store( |
| 172 | + out_topk_weights + token_index * out_topk_weights_stride_m + offs_n, sorted_scores, mask=offs_n < topk_num |
| 173 | + ) |
| 174 | + tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num) |
| 175 | + return |
| 176 | + |
| 177 | + |
| 178 | +def triton_grouped_topk( |
| 179 | + hidden_states: torch.Tensor, |
| 180 | + gating_output: torch.Tensor, |
| 181 | + correction_bias: torch.Tensor, |
| 182 | + topk: int, |
| 183 | + renormalize: bool, |
| 184 | + num_expert_group: int = 0, |
| 185 | + topk_group: int = 0, |
| 186 | + scoring_func: str = "softmax", |
| 187 | +): |
| 188 | + |
| 189 | + if correction_bias is not None: |
| 190 | + has_correction_bias = True |
| 191 | + else: |
| 192 | + has_correction_bias = False |
| 193 | + |
| 194 | + token_num, total_expert_num = gating_output.shape |
| 195 | + if gating_output.dtype == torch.float64: |
| 196 | + dtype = torch.float64 |
| 197 | + else: |
| 198 | + dtype = torch.float32 |
| 199 | + |
| 200 | + scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") |
| 201 | + out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") |
| 202 | + out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda") |
| 203 | + |
| 204 | + assert total_expert_num % num_expert_group == 0 |
| 205 | + |
| 206 | + grouped_topk_kernel[(token_num,)]( |
| 207 | + gating_output, |
| 208 | + *gating_output.stride(), |
| 209 | + correction_bias, |
| 210 | + scores_buffer, |
| 211 | + *scores_buffer.stride(), |
| 212 | + *scores_buffer.view(token_num, num_expert_group, -1).stride(), |
| 213 | + out_topk_weights, |
| 214 | + *out_topk_weights.stride(), |
| 215 | + out_topk_ids, |
| 216 | + *out_topk_ids.stride(), |
| 217 | + group_num=num_expert_group, |
| 218 | + group_expert_num=total_expert_num // num_expert_group, |
| 219 | + total_expert_num=total_expert_num, |
| 220 | + topk_num=topk, |
| 221 | + group_topk_num=topk_group, |
| 222 | + IS_SIGMOID=scoring_func == "sigmoid", |
| 223 | + HAS_CORRECTION_BIAS=has_correction_bias, |
| 224 | + EXPERT_BLOCK_SIZE=triton.next_power_of_2(total_expert_num), |
| 225 | + EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group), |
| 226 | + EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group), |
| 227 | + RENORMALIZE=renormalize, |
| 228 | + num_warps=1, |
| 229 | + num_stages=1, |
| 230 | + ) |
| 231 | + return out_topk_weights, out_topk_ids |
0 commit comments