import triton
import triton.language as tl
import torch


@triton.jit
def _fwd_kernel_mtp_verify(
    req_to_next_token_ids,
    req_to_next_token_ids_stride,
    new_next_token_ids,
    mtp_accept_len,
    b_req_mtp_start_loc,
    b_req_idx,
    b_mtp_index,
    accepted_index,
    batch_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    cur_index = tl.program_id(0)
    req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
    cur_req_idx = tl.load(b_req_idx + req_start_loc)
    offset = tl.arange(0, BLOCK_SIZE)
    req_offset = req_start_loc + offset
    cur_mtp_index = tl.load(b_mtp_index + req_offset, mask=req_offset < batch_size)

    # A lane belongs to this request only if its mtp index equals its offset.
    mask = cur_mtp_index == offset

    # Draft tokens stored for this request (columns 1..) and the newly sampled tokens;
    # the distinct `other` values guarantee a mismatch outside the mask.
    cur_next_token_id = tl.load(
        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + offset + 1, mask=mask, other=-1
    )
    cur_new_next_token_id = tl.load(new_next_token_ids + req_offset, mask=mask, other=-2)

    match_mask = cur_next_token_id == cur_new_next_token_id

    # accept_len is the offset of the first mismatch.
    first_false = tl.where(~match_mask, offset, BLOCK_SIZE - 1)
    accept_len = tl.min(first_false)
    tl.store(mtp_accept_len + cur_index, accept_len)
    accepted_flags = tl.where(offset < accept_len + 1, 1, 0)
    tl.store(accepted_index + req_offset, accepted_flags, mask=mask)
    return


def mtp_verify(
    req_to_next_token_ids: torch.Tensor,
    b_req_mtp_start_loc: torch.Tensor,
    new_next_token_ids: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
):
    """
    Verify the draft (MTP) tokens of each request against the newly sampled tokens and
    compute the accepted length per request.
    Args:
        req_to_next_token_ids: (max_req_num, max_mtp_step)
        b_req_mtp_start_loc: (num_reqs,)
        new_next_token_ids: (batch_size,)
        b_req_idx: (batch_size,)
        b_mtp_index: (batch_size,)
    Returns:
        mtp_accept_len: (num_reqs,) accepted length of each request.
        accepted_index: (batch_size,) e.g. [1, 0, 1, 1, 0], where 1 means the token at that
            position is accepted and 0 means it is rejected.
    """
    max_mtp_step = req_to_next_token_ids.shape[1]
    BLOCK_SIZE = 16
    assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must not exceed {BLOCK_SIZE}"
    num_reqs = b_req_mtp_start_loc.shape[0]
    batch_size = b_req_idx.shape[0]
    mtp_accept_len = torch.empty((num_reqs,), dtype=torch.int32, device=req_to_next_token_ids.device)
    accepted_index = torch.empty((batch_size,), dtype=torch.int32, device=req_to_next_token_ids.device)

    grid = (num_reqs,)
    num_warps = 1
    _fwd_kernel_mtp_verify[grid](
        req_to_next_token_ids,
        req_to_next_token_ids.stride(0),
        new_next_token_ids,
        mtp_accept_len,
        b_req_mtp_start_loc,
        b_req_idx,
        b_mtp_index,
        accepted_index,
        batch_size,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=1,
    )
    return mtp_accept_len, accepted_index
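

# A hypothetical, loop-based PyTorch sketch of what `mtp_verify` computes, assuming each
# request's positions are contiguous in the batch and `b_mtp_index` counts 0, 1, 2, ...
# within a request. The name `_mtp_verify_reference` and the function itself are only an
# illustration; the Triton kernel above is the actual implementation.
def _mtp_verify_reference(
    req_to_next_token_ids: torch.Tensor,
    b_req_mtp_start_loc: torch.Tensor,
    new_next_token_ids: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
):
    num_reqs = b_req_mtp_start_loc.shape[0]
    batch_size = b_req_idx.shape[0]
    mtp_accept_len = torch.zeros((num_reqs,), dtype=torch.int32, device=b_req_idx.device)
    accepted_index = torch.zeros((batch_size,), dtype=torch.int32, device=b_req_idx.device)
    for r in range(num_reqs):
        start = int(b_req_mtp_start_loc[r])
        end = int(b_req_mtp_start_loc[r + 1]) if r + 1 < num_reqs else batch_size
        req_idx = int(b_req_idx[start])
        # accept_len is the index of the first draft token that disagrees with the newly
        # sampled token; if every position agrees it equals the request's position count.
        accept_len = end - start
        for step in range(end - start):
            draft = int(req_to_next_token_ids[req_idx, step + 1])
            sampled = int(new_next_token_ids[start + step])
            if draft != sampled:
                accept_len = step
                break
        mtp_accept_len[r] = accept_len
        for step in range(end - start):
            accepted_index[start + step] = 1 if step <= accept_len else 0
    return mtp_accept_len, accepted_index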


@triton.jit
def _fwd_kernel_mtp_scatter_next_token_ids(
    req_to_next_token_ids,
    req_to_next_token_ids_stride,
    all_next_token_ids,
    all_next_token_ids_stride,
    mtp_accept_len,
    b_req_mtp_start_loc,
    b_req_idx,
    b_mtp_index,
    mtp_step: tl.constexpr,
    batch_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    cur_index = tl.program_id(0)
    req_start_loc = tl.load(b_req_mtp_start_loc + cur_index)
    accept_len = tl.load(mtp_accept_len + cur_index)
    cur_req_idx = tl.load(b_req_idx + req_start_loc)
    offset = tl.arange(0, BLOCK_SIZE)
    req_offset = req_start_loc + offset
    cur_mtp_index = tl.load(b_mtp_index + req_offset, mask=req_offset < batch_size)

    # A lane belongs to this request only if its mtp index equals its offset.
    mask = cur_mtp_index == offset
    # Pick the row of all_next_token_ids that corresponds to the accepted position.
    scatter_next_token_ids = tl.load(
        all_next_token_ids + (req_start_loc + accept_len) * all_next_token_ids_stride + offset,
        mask=offset < mtp_step,
        other=0,
    )
    # Columns that do not correspond to a live mtp position are filled with -1.
    scatter_next_token_ids = tl.where(mask, scatter_next_token_ids, -1)
    tl.store(
        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + offset,
        scatter_next_token_ids,
        mask=offset < mtp_step,
    )
    return


def mtp_scatter_next_token_ids(
    req_to_next_token_ids: torch.Tensor,
    b_req_mtp_start_loc: torch.Tensor,
    all_next_token_ids: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
    mtp_accept_len: torch.Tensor,
):
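    """
    Scatter the next-token ids selected by the verified accept length back into
    req_to_next_token_ids. For each request, the row of all_next_token_ids at position
    (req_start_loc + accept_len) is written to req_to_next_token_ids[req_idx, :mtp_step];
    columns beyond the request's live mtp positions are filled with -1.
    Args:
        req_to_next_token_ids: (max_req_num, max_mtp_step)
        b_req_mtp_start_loc: (num_reqs,)
        all_next_token_ids: (batch_size, mtp_step)
        b_req_idx: (batch_size,)
        b_mtp_index: (batch_size,)
        mtp_accept_len: (num_reqs,)
    """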
    max_mtp_step = req_to_next_token_ids.shape[1]
    BLOCK_SIZE = 16
    assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must not exceed {BLOCK_SIZE}"
    num_reqs = b_req_mtp_start_loc.shape[0]
    batch_size = b_req_idx.shape[0]
    mtp_step = all_next_token_ids.shape[1]
    grid = (num_reqs,)
    num_warps = 1
    _fwd_kernel_mtp_scatter_next_token_ids[grid](
        req_to_next_token_ids,
        req_to_next_token_ids.stride(0),
        all_next_token_ids,
        all_next_token_ids.stride(0),
        mtp_accept_len,
        b_req_mtp_start_loc,
        b_req_idx,
        b_mtp_index,
        mtp_step,
        batch_size,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=1,
    )


@triton.jit
def _fwd_kernel_gen_b_req_mtp_start_loc(
    b_mtp_index,
    b_req_mtp_start_loc,
    num_reqs: tl.constexpr,
    batch_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    offset = tl.arange(0, BLOCK_SIZE)
    cur_mtp_index = tl.load(b_mtp_index + offset, mask=offset < batch_size, other=-1)
    # Mark the positions where a new request starts (mtp index == 0), e.g. [1, 0, 1, 0, 0].
    start_mask = tl.where(cur_mtp_index == 0, 1, 0)
    # The cumulative sum turns each marked position into its destination index.
    output_offset = tl.cumsum(start_mask) - 1
    tl.store(b_req_mtp_start_loc + output_offset, offset, mask=start_mask == 1)
    return


def gen_b_req_mtp_start_loc(b_mtp_index: torch.Tensor, num_reqs: int):
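    """
    Compute, for each request, the index of its first position in the flattened batch,
    i.e. the positions where b_mtp_index == 0 (the same values as
    torch.where(b_mtp_index == 0)[0], produced in a single kernel launch as int32).
    Args:
        b_mtp_index: (batch_size,)
        num_reqs: number of requests in the batch.
    Returns:
        b_req_mtp_start_loc: (num_reqs,) int32 tensor of request start locations.
    """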
    b_req_mtp_start_loc = torch.empty((num_reqs,), dtype=torch.int32, device=b_mtp_index.device)
    BLOCK_SIZE = triton.next_power_of_2(b_mtp_index.shape[0])
    batch_size = b_mtp_index.shape[0]
    grid = (1,)
    _fwd_kernel_gen_b_req_mtp_start_loc[grid](
        b_mtp_index=b_mtp_index,
        b_req_mtp_start_loc=b_req_mtp_start_loc,
        num_reqs=num_reqs,
        batch_size=batch_size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=8,
    )
    return b_req_mtp_start_loc


def test_mtp_verify():
    req_to_next_token_ids = torch.tensor(
        [[1, 2, -2, -1, -1], [1, 2, 0, -1, -1], [1, 3, 4, 4, 5]], dtype=torch.int32, device="cuda"
    )
    b_req_idx = torch.tensor([0, 0, 2, 2, 2], dtype=torch.int32, device="cuda")
    b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda")
    b_req_mtp_start_loc = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
    new_next_token_ids = torch.tensor([1, 4, 3, 4, 13], dtype=torch.int32, device="cuda")
    all_next_token_ids = torch.tensor(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]], dtype=torch.int32, device="cuda"
    )
    mtp_accept_len, accepted_index = mtp_verify(
        req_to_next_token_ids, b_req_mtp_start_loc, new_next_token_ids, b_req_idx, b_mtp_index
    )
    mtp_scatter_next_token_ids(
        req_to_next_token_ids, b_req_mtp_start_loc, all_next_token_ids, b_req_idx, b_mtp_index, mtp_accept_len
    )
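    # Hand-derived expectation for this input (illustrative, not asserted):
    #   mtp_accept_len -> [0, 2]
    #   accepted_index -> [1, 0, 1, 1, 1]
    #   req_to_next_token_ids[0, :3] -> [1, 2, -1], req_to_next_token_ids[2, :3] -> [13, 14, 15]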
    print(mtp_accept_len)
    print(req_to_next_token_ids)
    print(accepted_index)


def test_gen_b_req_mtp_start_loc():
    b_mtp_index = torch.tensor([0, 1, 0, 1, 2], dtype=torch.int32, device="cuda")
    gt_output = torch.where(b_mtp_index == 0)[0]
    b_req_mtp_start_loc = gen_b_req_mtp_start_loc(b_mtp_index, 2)
    print(b_req_mtp_start_loc, gt_output)


if __name__ == "__main__":
    # test_mtp_verify()
    test_gen_b_req_mtp_start_loc()