@@ -17,37 +17,43 @@ def rotary_kernel(
1717 stride_cos_d ,
1818 stride_sin_l ,
1919 stride_sin_d ,
20+ max_L ,
2021 D : tl .constexpr ,
2122 HALF_D : tl .constexpr ,
2223 BLOCK_D : tl .constexpr ,
24+ BLOCK_L : tl .constexpr ,
2325):
2426 pid_h = tl .program_id (0 ).to (tl .int64 )
25- pid_l = tl .program_id (1 ).to (tl .int64 )
26- pid_blk = tl .program_id (2 ).to (tl .int64 )
27+ pid_l_blk = tl .program_id (1 ).to (tl .int64 )
28+ pid_blk_d = tl .program_id (2 ).to (tl .int64 )
2729
2830 offs_d = tl .arange (0 , BLOCK_D )
29- d = pid_blk * BLOCK_D + offs_d
31+ d = pid_blk_d * BLOCK_D + offs_d
3032 mask = d < D
3133
32- base = pid_l * stride_l + pid_h * stride_h
34+ seq_start = pid_l_blk * BLOCK_L
35+ seq_end = (pid_l_blk + 1 ) * BLOCK_L
36+ seq_end = tl .where (seq_end < max_L , seq_end , max_L )
3337
34- in_ptr = inp_ptr + base + d * stride_d
35- cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36- sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
38+ for seq in tl .range (seq_start , seq_end ):
39+ base = seq * stride_l + pid_h * stride_h
3740
38- x = tl . load ( in_ptr , mask = mask )
39- cos = tl . load ( cos_ptr_ , mask = mask )
40- sin = tl . load ( sin_ptr_ , mask = mask )
41+ in_ptr = inp_ptr + base + d * stride_d
42+ cos_ptr_ = cos_ptr + seq * stride_cos_l + d * stride_cos_d
43+ sin_ptr_ = sin_ptr + seq * stride_sin_l + d * stride_sin_d
4144
42- partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
43- partner_ptr = inp_ptr + base + partner_d * stride_d
44- partner_val = tl .load (partner_ptr , mask = mask )
45- rotated = tl .where (d < HALF_D , - partner_val , partner_val )
45+ x = tl .load (in_ptr , mask = mask , other = 0.0 )
46+ cos = tl .load (cos_ptr_ , mask = mask , other = 0.0 )
47+ sin = tl .load (sin_ptr_ , mask = mask , other = 0.0 )
4648
47- y = x * cos + rotated * sin
49+ partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
50+ partner_ptr = inp_ptr + base + partner_d * stride_d
51+ partner_val = tl .load (partner_ptr , mask = mask , other = 0.0 )
52+ rotated = tl .where (d < HALF_D , - partner_val , partner_val )
4853
49- out_ptr_ = out_ptr + base + d
50- tl .store (out_ptr_ , y , mask = mask )
54+ y = x * cos + rotated * sin
55+ out_ptr_ = out_ptr + base + d * stride_d
56+ tl .store (out_ptr_ , y , mask = mask )
5157
5258
5359def apply_rotary_pos_emb_triton (
@@ -67,7 +73,8 @@ def apply_rotary_pos_emb_triton(
6773 HALF_D = D // 2
6874 y = torch .empty_like (x )
6975
70- grid = (H , L , triton .cdiv (D , BLOCK_D ))
76+ BLOCK_L = 16
77+ grid = (H , triton .cdiv (L , BLOCK_L ), triton .cdiv (D , BLOCK_D ))
7178
7279 rotary_kernel [grid ](
7380 inp_ptr = x ,
@@ -81,9 +88,11 @@ def apply_rotary_pos_emb_triton(
8188 stride_cos_d = cos .stride (1 ),
8289 stride_sin_l = sin .stride (0 ),
8390 stride_sin_d = sin .stride (1 ),
91+ max_L = L ,
8492 D = D ,
8593 HALF_D = HALF_D ,
8694 BLOCK_D = BLOCK_D ,
95+ BLOCK_L = BLOCK_L ,
8796 )
8897
8998 return y .to (orig_dtype )
0 commit comments