|
| 1 | +import torch |
| 2 | + |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | +import torch.nn.functional as F |
| 6 | +import numpy as np |
| 7 | + |
| 8 | + |
| 9 | +@triton.jit |
| 10 | +def _fwd_kernel_apply_penalty_cache( |
| 11 | + Logits, |
| 12 | + req_idxs, |
| 13 | + presence_penalty, |
| 14 | + freqency_penalty, |
| 15 | + repetition_penalty, |
| 16 | + p_token_vocabs, |
| 17 | + stride_logit_b, |
| 18 | + stride_p_token_vocabs_b, |
| 19 | + BLOCK_P: tl.constexpr, |
| 20 | +): |
| 21 | + cur_batch = tl.program_id(0) |
| 22 | + block_idx = tl.program_id(1) |
| 23 | + token_idx = tl.load(req_idxs + cur_batch) |
| 24 | + cur_freqency = tl.load(freqency_penalty + token_idx) |
| 25 | + cur_presence = tl.load(presence_penalty + token_idx) |
| 26 | + cur_repetition = tl.load(repetition_penalty + token_idx) |
| 27 | + |
| 28 | + batch_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P) |
| 29 | + batch_ids_count = tl.load( |
| 30 | + p_token_vocabs + token_idx * stride_p_token_vocabs_b + batch_ids, |
| 31 | + mask=batch_ids < stride_p_token_vocabs_b, |
| 32 | + other=0, |
| 33 | + ) |
| 34 | + row_start_ptr = Logits + cur_batch * stride_logit_b |
| 35 | + cur_offset = row_start_ptr + batch_ids |
| 36 | + cur_logits = tl.load(cur_offset, mask=batch_ids_count > 0, other=0.0) |
| 37 | + rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, cur_logits * cur_repetition) |
| 38 | + freq_logits = rep_logits - batch_ids_count * cur_freqency |
| 39 | + pre_logits = freq_logits - cur_presence |
| 40 | + output_ptr = Logits + cur_batch * stride_logit_b + batch_ids |
| 41 | + tl.store(output_ptr, pre_logits, mask=batch_ids_count > 0) |
| 42 | + return |
| 43 | + |
| 44 | + |
| 45 | +@triton.jit |
| 46 | +def _eos_penalty( |
| 47 | + Logits, |
| 48 | + req_idxs, |
| 49 | + p_token_lens, |
| 50 | + exponential_decay_length_penalties, |
| 51 | + length_penalty_idx, |
| 52 | + eos_ids, |
| 53 | + mask_eos_reqs, |
| 54 | + stride_logit_b, |
| 55 | + EOS_ID_NUM: tl.constexpr, |
| 56 | +): |
| 57 | + cur_batch = tl.program_id(0) |
| 58 | + token_idx = tl.load(req_idxs + cur_batch) |
| 59 | + exponential_decay_length_penalty = tl.load(exponential_decay_length_penalties + token_idx) |
| 60 | + token_lens = tl.load(p_token_lens + cur_batch) |
| 61 | + length_penalty = tl.maximum(token_lens - tl.load(length_penalty_idx + token_idx), 0) |
| 62 | + penalty_scale = tl.exp2(tl.log2(exponential_decay_length_penalty) * length_penalty) - 1 |
| 63 | + mask_eos = tl.load(mask_eos_reqs + token_idx) |
| 64 | + for eos_index in range(EOS_ID_NUM): |
| 65 | + eos_id = tl.load(eos_ids + eos_index) |
| 66 | + cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id |
| 67 | + cur_eos_logit = tl.load(cur_eos_logit_ptr) |
| 68 | + cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale |
| 69 | + cur_eos_logit = tl.where(token_lens < mask_eos, -10000000.0, cur_eos_logit) |
| 70 | + tl.store(cur_eos_logit_ptr, cur_eos_logit) |
| 71 | + return |
| 72 | + |
| 73 | + |
| 74 | +@torch.no_grad() |
| 75 | +def apply_penalty_cache( |
| 76 | + Logits, |
| 77 | + req_idxs, |
| 78 | + presence_penalty, |
| 79 | + freqency_penalty, |
| 80 | + repetition_penalty, |
| 81 | + p_token_vocabs, |
| 82 | + p_token_lens, |
| 83 | + exponential_decay_length_penalties, |
| 84 | + length_penalty_idx, |
| 85 | + eos_ids, |
| 86 | + mask_eos_reqs, |
| 87 | + is_eos_penalty=False, |
| 88 | +): |
| 89 | + assert Logits.is_contiguous() |
| 90 | + BLOCK_P = 1024 |
| 91 | + num_warps = 8 |
| 92 | + vocab_size = p_token_vocabs.shape[1] |
| 93 | + _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))]( |
| 94 | + Logits, |
| 95 | + req_idxs, |
| 96 | + presence_penalty, |
| 97 | + freqency_penalty, |
| 98 | + repetition_penalty, |
| 99 | + p_token_vocabs, |
| 100 | + Logits.stride(0), |
| 101 | + p_token_vocabs.stride(0), |
| 102 | + num_warps=num_warps, |
| 103 | + BLOCK_P=BLOCK_P, |
| 104 | + ) |
| 105 | + if is_eos_penalty: |
| 106 | + p_token_lens = p_token_vocabs[req_idxs].sum(dim=1).cuda() if p_token_lens is None else p_token_lens |
| 107 | + _eos_penalty[(Logits.shape[0],)]( |
| 108 | + Logits, |
| 109 | + req_idxs, |
| 110 | + p_token_lens, |
| 111 | + exponential_decay_length_penalties, |
| 112 | + length_penalty_idx, |
| 113 | + eos_ids, |
| 114 | + mask_eos_reqs, |
| 115 | + Logits.stride(0), |
| 116 | + num_warps=num_warps, |
| 117 | + EOS_ID_NUM=eos_ids.shape[0], |
| 118 | + ) |
| 119 | + return |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == "__main__": |
| 123 | + from .apply_penalty import apply_penalty |
| 124 | + |
| 125 | + bs = 200 |
| 126 | + vocab_size = 150000 |
| 127 | + p_tokens = 2000 |
| 128 | + repseats = 1000 |
| 129 | + req_idxs = torch.arange(bs).cuda() |
| 130 | + logits = torch.randn((bs, vocab_size), dtype=torch.float32).cuda() |
| 131 | + logits2 = logits.clone() |
| 132 | + |
| 133 | + presence_penalty = torch.randn((bs,), dtype=torch.float32).cuda() + 1e-5 |
| 134 | + freqency_penalty = torch.randn((bs,), dtype=torch.float32).cuda() |
| 135 | + repetition_penalty = torch.randn((bs,), dtype=torch.float32).cuda() |
| 136 | + exponential_decay_length_penalties = torch.rand(bs).cuda() |
| 137 | + eos_ids = torch.tensor([999]).cuda() |
| 138 | + |
| 139 | + p_seq_len = torch.cat([torch.tensor([0]), torch.randint(1, p_tokens, (bs,))]).cuda() |
| 140 | + p_token_ids = torch.randint(0, vocab_size, (p_seq_len.sum(),)).cuda() |
| 141 | + i = 0 |
| 142 | + for s_l in p_seq_len[1:]: |
| 143 | + p_token_ids[i : i + s_l] = torch.randperm(vocab_size)[:s_l] |
| 144 | + i += s_l |
| 145 | + p_token_counts = torch.randint(1, repseats, (p_seq_len.sum(),)).cuda() |
| 146 | + p_cumsum_seq_len = p_seq_len.cumsum(dim=0).cuda() |
| 147 | + p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int16).cuda() |
| 148 | + i = 0 |
| 149 | + b = 0 |
| 150 | + for token_id, count in zip(p_token_ids, p_token_counts): |
| 151 | + p_token_vocabs[b][token_id] = count |
| 152 | + i += 1 |
| 153 | + if i == p_seq_len[b + 1]: |
| 154 | + b += 1 |
| 155 | + i = 0 |
| 156 | + |
| 157 | + p_token_lens = p_token_vocabs.sum(dim=1).cuda() |
| 158 | + length_penalty_idx = torch.randint(0, p_tokens, (bs,)).cuda() |
| 159 | + len_idx = torch.tensor([max(p_token_lens[i] - length_penalty_idx[i], 0) for i in range(bs)]).cuda() |
| 160 | + mask_eos_reqs = torch.randint(1, p_tokens, (bs,)).cuda() |
| 161 | + mask_bool = torch.tensor([p_token_lens[i] < mask_eos_reqs[i] for i in range(bs)]).cuda() |
| 162 | + |
| 163 | + fn1 = lambda: apply_penalty( |
| 164 | + logits, |
| 165 | + presence_penalty, |
| 166 | + freqency_penalty, |
| 167 | + repetition_penalty, |
| 168 | + p_token_ids, |
| 169 | + p_token_counts, |
| 170 | + p_cumsum_seq_len, |
| 171 | + exponential_decay_length_penalties, |
| 172 | + len_idx, |
| 173 | + eos_ids, |
| 174 | + mask_bool, |
| 175 | + ) |
| 176 | + |
| 177 | + fn2 = lambda: apply_penalty_cache( |
| 178 | + logits2, |
| 179 | + req_idxs, |
| 180 | + presence_penalty, |
| 181 | + freqency_penalty, |
| 182 | + repetition_penalty, |
| 183 | + p_token_vocabs, |
| 184 | + p_token_lens, |
| 185 | + exponential_decay_length_penalties, |
| 186 | + length_penalty_idx, |
| 187 | + eos_ids, |
| 188 | + mask_eos_reqs, |
| 189 | + ) |
| 190 | + fn1() |
| 191 | + fn2() |
| 192 | + cos = F.cosine_similarity(logits, logits2).mean() |
| 193 | + print("cos =", cos) |
| 194 | + assert torch.allclose(logits, logits2, atol=1e-2, rtol=0) |
| 195 | + |
| 196 | + ms1 = triton.testing.do_bench(fn1) |
| 197 | + ms2 = triton.testing.do_bench(fn2) |
| 198 | + print("ms1 =", ms1, "ms2 =", ms2) |
0 commit comments