|
| 1 | +import torch |
| 2 | + |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | + |
| 7 | +@triton.jit |
| 8 | +def _fwd_kernel_gather_and_scatter( |
| 9 | + probs_idx, |
| 10 | + probs_sort, |
| 11 | + req_to_next_token_ids, |
| 12 | + req_to_next_token_probs, |
| 13 | + sampled_index, |
| 14 | + b_req_idx, |
| 15 | + probs_idx_stride, |
| 16 | + probs_sort_stride, |
| 17 | + req_to_next_token_ids_stride, |
| 18 | + req_to_next_token_probs_stride, |
| 19 | +): |
| 20 | + cur_index = tl.program_id(0) |
| 21 | + cur_req_idx = tl.load(b_req_idx + cur_index) |
| 22 | + cur_sampled_index = tl.load(sampled_index + cur_index) |
| 23 | + cur_token_index = tl.load(probs_idx + cur_index * probs_idx_stride + cur_sampled_index) |
| 24 | + cur_token_probs = tl.load(probs_sort + cur_index * probs_sort_stride + cur_sampled_index) |
| 25 | + tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride, cur_token_index) |
| 26 | + tl.store(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride, tl.log(cur_token_probs)) |
| 27 | + return |
| 28 | + |
| 29 | + |
| 30 | +@torch.no_grad() |
| 31 | +def gather_and_scatter_token_to_cpu( |
| 32 | + probs_idx: torch.Tensor, |
| 33 | + probs_sort: torch.Tensor, |
| 34 | + req_to_next_token_ids: torch.Tensor, |
| 35 | + req_to_next_token_probs: torch.Tensor, |
| 36 | + sampled_index: torch.Tensor, |
| 37 | + b_req_idx: torch.Tensor, |
| 38 | +): |
| 39 | + """ |
| 40 | + This function is used to gather the next_token_id(GPU tensor) and next_token_probs(GPU tensor) |
| 41 | + info to the req_to_next_token_ids and req_to_next_token_probs(CPU tensor). |
| 42 | + Args: |
| 43 | + probs_idx: (batch_size, vocab_size) |
| 44 | + probs_sort: (batch_size, vocab_size) |
| 45 | + req_to_next_token_ids: (max_req_num,) |
| 46 | + req_to_next_token_probs: (max_req_num,) |
| 47 | + sampled_index: (batch_size,) |
| 48 | + b_req_idx: (batch_size,) |
| 49 | + """ |
| 50 | + assert probs_idx.shape == probs_sort.shape |
| 51 | + assert sampled_index.shape[0] == b_req_idx.shape[0] |
| 52 | + batch_size = b_req_idx.shape[0] |
| 53 | + grid = (batch_size,) |
| 54 | + num_warps = 1 |
| 55 | + |
| 56 | + _fwd_kernel_gather_and_scatter[grid]( |
| 57 | + probs_idx, |
| 58 | + probs_sort, |
| 59 | + req_to_next_token_ids, |
| 60 | + req_to_next_token_probs, |
| 61 | + sampled_index, |
| 62 | + b_req_idx, |
| 63 | + probs_idx.stride(0), |
| 64 | + probs_sort.stride(0), |
| 65 | + req_to_next_token_ids.stride(0), |
| 66 | + req_to_next_token_probs.stride(0), |
| 67 | + num_warps=num_warps, |
| 68 | + num_stages=1, |
| 69 | + ) |
| 70 | + return |
| 71 | + |
| 72 | + |
| 73 | +@triton.jit |
| 74 | +def _fwd_kernel_scatter( |
| 75 | + token_info, |
| 76 | + req_to_token_info, |
| 77 | + b_req_idx, |
| 78 | + req_to_token_info_stride, |
| 79 | +): |
| 80 | + cur_index = tl.program_id(0) |
| 81 | + cur_req_idx = tl.load(b_req_idx + cur_index) |
| 82 | + cur_token_info = tl.load(token_info + cur_index) |
| 83 | + tl.store(req_to_token_info + cur_req_idx * req_to_token_info_stride, cur_token_info) |
| 84 | + return |
| 85 | + |
| 86 | + |
| 87 | +@torch.no_grad() |
| 88 | +def scatter_token_to_cpu(token_info: torch.Tensor, req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor): |
| 89 | + """ |
| 90 | + This function is used to scatter the token_info(GPU tensor) to the req_to_token_info(CPU tensor). |
| 91 | + Args: |
| 92 | + token_info: (batch_size, vocab_size) |
| 93 | + req_to_token_info: (max_req_num,) |
| 94 | + b_req_idx: (batch_size,) |
| 95 | + """ |
| 96 | + assert token_info.shape[0] == b_req_idx.shape[0] |
| 97 | + batch_size = b_req_idx.shape[0] |
| 98 | + grid = (batch_size,) |
| 99 | + num_warps = 1 |
| 100 | + |
| 101 | + _fwd_kernel_scatter[grid]( |
| 102 | + token_info, |
| 103 | + req_to_token_info, |
| 104 | + b_req_idx, |
| 105 | + req_to_token_info.stride(0), |
| 106 | + num_warps=num_warps, |
| 107 | + num_stages=1, |
| 108 | + ) |
| 109 | + return |
| 110 | + |
| 111 | + |
| 112 | +@triton.jit |
| 113 | +def _fwd_kernel_gather( |
| 114 | + req_to_token_info, |
| 115 | + output, |
| 116 | + b_req_idx, |
| 117 | +): |
| 118 | + cur_index = tl.program_id(0) |
| 119 | + cur_req_idx = tl.load(b_req_idx + cur_index) |
| 120 | + cur_token_info = tl.load(req_to_token_info + cur_req_idx) |
| 121 | + tl.store(output + cur_index, cur_token_info) |
| 122 | + return |
| 123 | + |
| 124 | + |
| 125 | +def gather_token_from_cpu(req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor): |
| 126 | + """ |
| 127 | + This function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor). |
| 128 | + Args: |
| 129 | + req_to_token_info: (max_req_num,) |
| 130 | + b_req_idx: (batch_size,) |
| 131 | + Returns: |
| 132 | + output: (batch_size,) |
| 133 | + """ |
| 134 | + batch_size = b_req_idx.shape[0] |
| 135 | + output = torch.empty_like(b_req_idx) |
| 136 | + grid = (batch_size,) |
| 137 | + num_warps = 1 |
| 138 | + _fwd_kernel_gather[grid]( |
| 139 | + req_to_token_info, |
| 140 | + output, |
| 141 | + b_req_idx, |
| 142 | + num_warps=num_warps, |
| 143 | + num_stages=1, |
| 144 | + ) |
| 145 | + return output |
| 146 | + |
| 147 | + |
| 148 | +def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): |
| 149 | + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) |
| 150 | + |
| 151 | + probs_sum = torch.cumsum(probs_sort, dim=-1) |
| 152 | + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 |
| 153 | + |
| 154 | + probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0 |
| 155 | + |
| 156 | + return probs_sort, probs_idx |
| 157 | + |
| 158 | + |
| 159 | +def test_gather_and_scatter_token_to_cpu(): |
| 160 | + batch_size = 30 |
| 161 | + vocab_size = 60000 |
| 162 | + req_to_next_token_ids = torch.ones((1000,), dtype=torch.int32, pin_memory=True) |
| 163 | + req_to_next_token_probs = torch.ones((1000,), dtype=torch.float32, pin_memory=True) |
| 164 | + req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda() |
| 165 | + probs = torch.randn((batch_size, vocab_size)).cuda() |
| 166 | + top_ps = torch.rand((batch_size,)).cuda() |
| 167 | + top_ks = torch.ones((batch_size,), dtype=torch.int32).cuda() |
| 168 | + probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks) |
| 169 | + sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) |
| 170 | + batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) |
| 171 | + batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index) |
| 172 | + |
| 173 | + gather_and_scatter_token_to_cpu( |
| 174 | + probs_idx, probs_sort, req_to_next_token_ids, req_to_next_token_probs, sampled_index, req_ids |
| 175 | + ) |
| 176 | + diff_ids = (req_to_next_token_ids[20 : 20 + batch_size].cuda() - batch_next_token_ids.view(-1)).abs().max() |
| 177 | + diff_probs = (req_to_next_token_probs[20 : 20 + batch_size].cuda() - batch_next_token_probs.view(-1)).abs().max() |
| 178 | + assert diff_ids < 1e-6 |
| 179 | + assert diff_probs < 1e-6 |
| 180 | + print("test_gather_and_scatter_token_to_cpu passed") |
| 181 | + |
| 182 | + |
| 183 | +def test_scatter_token_to_cpu(): |
| 184 | + batch_size = 30 |
| 185 | + req_to_token_info = torch.zeros((1000,), dtype=torch.float32, pin_memory=True) |
| 186 | + token_info = torch.randn((batch_size,)).cuda() |
| 187 | + req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda() |
| 188 | + scatter_token_to_cpu(token_info, req_to_token_info, req_ids) |
| 189 | + diff = (req_to_token_info[20 : 20 + batch_size].cuda() - token_info).abs().max() |
| 190 | + assert diff < 1e-6 |
| 191 | + print("test_scatter_token_to_cpu passed") |
| 192 | + |
| 193 | + |
| 194 | +def test_gather_token_from_cpu(): |
| 195 | + batch_size = 30 |
| 196 | + req_to_token_info = torch.zeros((1000,), dtype=torch.int32, pin_memory=True) |
| 197 | + token_info = torch.randn((batch_size,)).cuda() |
| 198 | + req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda() |
| 199 | + scatter_token_to_cpu(token_info, req_to_token_info, req_ids) |
| 200 | + output = gather_token_from_cpu(req_to_token_info, req_ids) |
| 201 | + diff = (token_info - output).abs().max() |
| 202 | + assert diff < 1e-6 |
| 203 | + print("test_gather_token_from_cpu passed") |
| 204 | + |
| 205 | + |
| 206 | +if __name__ == "__main__": |
| 207 | + test_gather_and_scatter_token_to_cpu() |
| 208 | + test_scatter_token_to_cpu() |
| 209 | + test_gather_token_from_cpu() |
0 commit comments