import numpy as np
import triton.language as tl


78@triton .jit
89def _fwd_kernel_apply_penalty (
9- Logits , presence_penalty , freqency_penalty , repetition_penalty ,
10- p_token_ids , p_token_counts , p_cumsum_seq_len ,
11- stride_logit_b , stride_logit_s ,
12- BLOCK_P : tl .constexpr
10+ Logits ,
11+ presence_penalty ,
12+ freqency_penalty ,
13+ repetition_penalty ,
14+ p_token_ids ,
15+ p_token_counts ,
16+ p_cumsum_seq_len ,
17+ exponential_decay_length_penalties ,
18+ length_penalty_idx ,
19+ eos_ids ,
20+ mask_eos_reqs ,
21+ stride_logit_b ,
22+ stride_logit_s ,
23+ BLOCK_P : tl .constexpr ,
24+ EOS_ID_NUM : tl .constexpr ,
1325):
1426 cur_batch = tl .program_id (0 )
1527 cur_freqency = tl .load (freqency_penalty + cur_batch )
@@ -18,36 +30,70 @@ def _fwd_kernel_apply_penalty(
1830
1931 cur_batch_start_index = tl .load (p_cumsum_seq_len + cur_batch )
2032 cur_batch_end_index = tl .load (p_cumsum_seq_len + cur_batch + 1 )
33+ for block_start_index in range (cur_batch_start_index , cur_batch_end_index , BLOCK_P ):
34+ cur_batch_id_offset = block_start_index + tl .arange (0 , BLOCK_P )
35+ batch_ids = tl .load (p_token_ids + cur_batch_id_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0 )
36+ batch_ids_count = tl .load (
37+ p_token_counts + cur_batch_id_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0
38+ )
39+
40+ row_start_ptr = Logits + cur_batch * stride_logit_b
41+ cur_offset = row_start_ptr + batch_ids
42+ cur_logits = tl .load (cur_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0.0 )
43+ rep_logits = tl .where (cur_logits > 0 , cur_logits / cur_repetition , cur_logits * cur_repetition )
44+ freq_logits = rep_logits - batch_ids_count * cur_freqency
45+ pre_logits = freq_logits - cur_presence
46+ output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
47+ tl .store (output_ptr , pre_logits , mask = cur_batch_id_offset < cur_batch_end_index )
2148
22- cur_batch_id_offset = cur_batch_start_index + tl .arange (0 , BLOCK_P )
23- batch_ids = tl .load (p_token_ids + cur_batch_id_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0 )
24- batch_ids_count = tl .load (p_token_counts + cur_batch_id_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0 )
25-
26- row_start_ptr = Logits + cur_batch * stride_logit_b
27- cur_offset = row_start_ptr + batch_ids
28- cur_logits = tl .load (cur_offset , mask = cur_batch_id_offset < cur_batch_end_index , other = 0.0 )
29- rep_logits = tl .where (cur_logits > 0 , cur_logits / cur_repetition , cur_logits * cur_repetition )
30- freq_logits = rep_logits - batch_ids_count * cur_freqency
31- pre_logits = freq_logits - cur_presence
32- output_ptr = Logits + cur_batch * stride_logit_b + batch_ids
33- tl .store (output_ptr , pre_logits , mask = cur_batch_id_offset < cur_batch_end_index )
49+ mask_eos = tl .load (mask_eos_reqs + cur_batch )
50+ exponential_decay_length_penalty = tl .load (exponential_decay_length_penalties + cur_batch )
51+ length_penalty = tl .load (length_penalty_idx + cur_batch )
52+ penalty_scale = tl .exp2 (tl .log2 (exponential_decay_length_penalty ) * length_penalty ) - 1
3453
54+ for eos_index in range (EOS_ID_NUM ):
55+ eos_id = tl .load (eos_ids + eos_index )
56+ cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
57+ cur_eos_logit = tl .load (cur_eos_logit_ptr )
58+ cur_eos_logit = cur_eos_logit + tl .abs (cur_eos_logit ) * penalty_scale
59+ cur_eos_logit = tl .where (mask_eos , - 10000000.0 , cur_eos_logit )
60+ tl .store (cur_eos_logit_ptr , cur_eos_logit )
3561 return
3662
63+
@torch.no_grad()
def apply_penalty(
    Logits,
    presence_penalty,
    freqency_penalty,
    repetition_penalty,
    p_token_ids,
    p_token_counts,
    p_cumsum_seq_len,
    exponential_decay_length_penalties,
    length_penalty_idx,
    eos_ids,
    mask_eos_reqs,
):
    """Launch the penalty kernel over every batch row of Logits (in place).

    Args:
        Logits: (batch, vocab) contiguous float tensor, modified in place.
        presence_penalty / freqency_penalty / repetition_penalty: per-request
            penalty tensors, one value per batch row.
        p_token_ids / p_token_counts: flattened seen-token ids and counts for
            all requests; p_cumsum_seq_len (len batch+1) gives each request's
            slice boundaries.
        exponential_decay_length_penalties / length_penalty_idx: per-request
            EOS decay base and current length index.
        eos_ids: 1-D tensor of EOS token ids shared by the batch.
        mask_eos_reqs: per-request bool/int flags; truthy rows get EOS masked
            to -1e7.
    """
    # Kernel tiles the token history, so the tile size no longer needs to
    # cover the longest history — a fixed 1024 block is sufficient.
    assert Logits.is_contiguous()
    BLOCK_P = 1024
    num_warps = 8
    _fwd_kernel_apply_penalty[(Logits.shape[0],)](
        Logits,
        presence_penalty,
        freqency_penalty,
        repetition_penalty,
        p_token_ids,
        p_token_counts,
        p_cumsum_seq_len,
        exponential_decay_length_penalties,
        length_penalty_idx,
        eos_ids,
        mask_eos_reqs,
        Logits.stride(0),
        Logits.stride(1),
        num_warps=num_warps,
        BLOCK_P=BLOCK_P,
        EOS_ID_NUM=eos_ids.shape[0],
    )
    return