 
 
 @triton.jit
-def _fwd_kernel_apply_penalty_cache(
+def _kernel_apply_penalty_cache(
     Logits,
     req_idxs,
     presence_penalty,
@@ -27,7 +27,7 @@ def _fwd_kernel_apply_penalty_cache(
 
     batch_ids = BLOCK_P * block_idx + tl.arange(0, BLOCK_P)
     batch_ids_count = tl.load(
-        p_token_vocabs + token_idx * stride_p_token_vocabs_b + batch_ids,
+        p_token_vocabs + cur_batch * stride_p_token_vocabs_b + batch_ids,
         mask=batch_ids < stride_p_token_vocabs_b,
         other=0,
     )
@@ -43,7 +43,7 @@ def _fwd_kernel_apply_penalty_cache(
 
 
 @triton.jit
-def _eos_penalty(
+def _kernel_eos_penalty(
     Logits,
     req_idxs,
     p_token_lens,
@@ -71,26 +71,58 @@ def _eos_penalty(
     return
 
 
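+# Per-request token histogram: program (cur_batch, block_idx) reads one BLOCK_SIZE
+# chunk of request req_idx's token-id list and atomically bumps the matching
+# counters in row cur_batch of the (batch, vocab) output buffer.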
+@triton.jit
+def _kernel_bincount(
+    req_idxs,
+    input,
+    output,
+    input_lens,
+    stride_input_b,
+    stride_output_b,
+    BLOCK_SIZE: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    req_idx = tl.load(req_idxs + cur_batch)
+    block_idx = tl.program_id(1)
+    input_ptr = input + req_idx * stride_input_b + block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_len = tl.load(input_lens + req_idx)
+    mask = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < input_len
+    token_id = tl.load(input_ptr, mask=mask, other=0)
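+    # one atomic increment per valid token: output[cur_batch, token_id] += 1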
+    tl.atomic_add(output + cur_batch * stride_output_b + token_id, 1, mask=mask)
+    return
+
+
 @torch.no_grad()
 def apply_penalty_cache(
     Logits,
     req_idxs,
     presence_penalty,
     freqency_penalty,
     repetition_penalty,
-    p_token_vocabs,
+    p_token_ids,
     p_token_lens,
     exponential_decay_length_penalties,
     length_penalty_idx,
     eos_ids,
     mask_eos_reqs,
-    is_eos_penalty=False,
+    vocab_size: tl.constexpr,
+    is_eos_penalty: tl.constexpr = False,
 ):
     assert Logits.is_contiguous()
     BLOCK_P = 1024
-    num_warps = 8
-    vocab_size = p_token_vocabs.shape[1]
-    _fwd_kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
+    num_warps = 4
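+    # rebuild the per-request (batch, vocab) count matrix on the fly from the flat
+    # p_token_ids layout instead of taking a pre-built p_token_vocabs cache as input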
+    p_token_vocabs = torch.zeros((Logits.shape[0], vocab_size), dtype=torch.int32, device="cuda")
+    _kernel_bincount[(Logits.shape[0], triton.cdiv(p_token_ids.stride(0), BLOCK_P))](
+        req_idxs,
+        p_token_ids,
+        p_token_vocabs,
+        p_token_lens,
+        p_token_ids.stride(0),
+        p_token_vocabs.stride(0),
+        num_warps=num_warps,
+        BLOCK_SIZE=BLOCK_P,
+    )
+    _kernel_apply_penalty_cache[(Logits.shape[0], triton.cdiv(vocab_size, BLOCK_P))](
         Logits,
         req_idxs,
         presence_penalty,
@@ -103,8 +135,7 @@ def apply_penalty_cache(
         BLOCK_P=BLOCK_P,
     )
     if is_eos_penalty:
-        p_token_lens = p_token_vocabs[req_idxs].sum(dim=1).cuda() if p_token_lens is None else p_token_lens
-        _eos_penalty[(Logits.shape[0],)](
+        _kernel_eos_penalty[(Logits.shape[0],)](
             Logits,
             req_idxs,
             p_token_lens,
@@ -121,11 +152,13 @@ def apply_penalty_cache(
 
 if __name__ == "__main__":
     from .apply_penalty import apply_penalty
+    from torch.nn.utils.rnn import pad_sequence
 
     bs = 200
     vocab_size = 150000
-    p_tokens = 2000
-    repseats = 1000
+    p_tokens = 3000
+    max_token_len = 16384
+    repseats = max_token_len // p_tokens
     req_idxs = torch.arange(bs).cuda()
     logits = torch.randn((bs, vocab_size), dtype=torch.float32).cuda()
     logits2 = logits.clone()
@@ -144,7 +177,7 @@ def apply_penalty_cache(
         i += s_l
     p_token_counts = torch.randint(1, repseats, (p_seq_len.sum(),)).cuda()
     p_cumsum_seq_len = p_seq_len.cumsum(dim=0).cuda()
-    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int16).cuda()
+    p_token_vocabs = torch.zeros((bs, vocab_size), dtype=torch.int32).cuda()
     i = 0
     b = 0
     for token_id, count in zip(p_token_ids, p_token_counts):
@@ -154,7 +187,12 @@ def apply_penalty_cache(
             b += 1
             i = 0
 
-    p_token_lens = p_token_vocabs.sum(dim=1).cuda()
+    p_token_lens = p_token_vocabs.cuda().sum(dim=1)
+    assert p_token_lens.max() < max_token_len
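+    # expand the count matrix back into explicit, right-padded token-id rows so the
+    # test can exercise the new flat p_token_ids input format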
+    token_idx = torch.arange(vocab_size).cuda()
+    sequences = [token_idx.repeat_interleave(p_token_vocabs[b]) for b in range(bs)]  # shape = [sum(rep_nums)]
+    p_token_mat = pad_sequence(sequences, batch_first=True, padding_value=0).to(torch.int32).cuda()
+
     length_penalty_idx = torch.randint(0, p_tokens, (bs,)).cuda()
     len_idx = torch.tensor([max(p_token_lens[i] - length_penalty_idx[i], 0) for i in range(bs)]).cuda()
     mask_eos_reqs = torch.randint(1, p_tokens, (bs,)).cuda()
@@ -180,12 +218,13 @@ def apply_penalty_cache(
         presence_penalty,
         freqency_penalty,
         repetition_penalty,
-        p_token_vocabs,
+        p_token_mat,
         p_token_lens,
         exponential_decay_length_penalties,
         length_penalty_idx,
         eos_ids,
         mask_eos_reqs,
+        vocab_size,
     )
     fn1()
     fn2()
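
A minimal PyTorch sketch of the per-request histogram that _kernel_bincount produces (reference only; the loop form and the name bincount_reference are illustrative, not part of the change, and the tensor layouts follow the __main__ test above):

import torch

def bincount_reference(req_idxs, p_token_ids, p_token_lens, vocab_size):
    # out[b, t] = number of times token t appears in the id list of request req_idxs[b]
    out = torch.zeros((req_idxs.numel(), vocab_size), dtype=torch.int32, device=p_token_ids.device)
    for b, req_idx in enumerate(req_idxs.tolist()):
        n = int(p_token_lens[req_idx])
        ids = p_token_ids[req_idx, :n].long()
        out[b] += torch.bincount(ids, minlength=vocab_size).to(torch.int32)
    return out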