@@ -17,42 +17,38 @@ def rotary_kernel(
1717 stride_cos_d ,
1818 stride_sin_l ,
1919 stride_sin_d ,
20- max_L ,
20+ total_len ,
2121 D : tl .constexpr ,
2222 HALF_D : tl .constexpr ,
2323 BLOCK_D : tl .constexpr ,
24- BLOCK_L : tl .constexpr ,
2524):
2625 pid_h = tl .program_id (0 ).to (tl .int64 )
27- pid_l_blk = tl .program_id (1 ).to (tl .int64 )
28- pid_blk_d = tl .program_id (2 ).to (tl .int64 )
26+ pid_l_start = tl .program_id (1 ).to (tl .int64 )
27+ pid_blk = tl .program_id (2 ).to (tl .int64 )
2928
3029 offs_d = tl .arange (0 , BLOCK_D )
31- d = pid_blk_d * BLOCK_D + offs_d
30+ d = pid_blk * BLOCK_D + offs_d
3231 mask = d < D
32+ for pid_l in tl .range (pid_l_start , total_len , step = tl .num_programs (axis = 1 )):
3333
34- seq_start = pid_l_blk * BLOCK_L
35- seq_end = (pid_l_blk + 1 ) * BLOCK_L
36- seq_end = tl .where (seq_end < max_L , seq_end , max_L )
37-
38- for seq in tl .range (seq_start , seq_end ):
39- base = seq * stride_l + pid_h * stride_h
34+ base = pid_l * stride_l + pid_h * stride_h
4035
4136 in_ptr = inp_ptr + base + d * stride_d
42- cos_ptr_ = cos_ptr + seq * stride_cos_l + d * stride_cos_d
43- sin_ptr_ = sin_ptr + seq * stride_sin_l + d * stride_sin_d
37+ cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
38+ sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
4439
45- x = tl .load (in_ptr , mask = mask , other = 0.0 )
46- cos = tl .load (cos_ptr_ , mask = mask , other = 0.0 )
47- sin = tl .load (sin_ptr_ , mask = mask , other = 0.0 )
40+ x = tl .load (in_ptr , mask = mask )
41+ cos = tl .load (cos_ptr_ , mask = mask )
42+ sin = tl .load (sin_ptr_ , mask = mask )
4843
4944 partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
5045 partner_ptr = inp_ptr + base + partner_d * stride_d
51- partner_val = tl .load (partner_ptr , mask = mask , other = 0.0 )
46+ partner_val = tl .load (partner_ptr , mask = mask )
5247 rotated = tl .where (d < HALF_D , - partner_val , partner_val )
5348
5449 y = x * cos + rotated * sin
55- out_ptr_ = out_ptr + base + d * stride_d
50+
51+ out_ptr_ = out_ptr + base + d
5652 tl .store (out_ptr_ , y , mask = mask )
5753
5854
@@ -72,12 +68,12 @@ def apply_rotary_pos_emb_triton(
7268 L , H , D = x .shape
7369 HALF_D = D // 2
7470 y = torch .empty_like (x )
75-
76- if L <= 4096 :
77- BLOCK_L = 16
71+ if L < 1024 :
72+ grid_L = L
7873 else :
79- BLOCK_L = 128
80- grid = (H , triton .cdiv (L , BLOCK_L ), triton .cdiv (D , BLOCK_D ))
74+ grid_L = 1024
75+
76+ grid = (H , grid_L , triton .cdiv (D , BLOCK_D ))
8177
8278 rotary_kernel [grid ](
8379 inp_ptr = x ,
@@ -91,11 +87,10 @@ def apply_rotary_pos_emb_triton(
9187 stride_cos_d = cos .stride (1 ),
9288 stride_sin_l = sin .stride (0 ),
9389 stride_sin_d = sin .stride (1 ),
94- max_L = L ,
90+ total_len = L ,
9591 D = D ,
9692 HALF_D = HALF_D ,
9793 BLOCK_D = BLOCK_D ,
98- BLOCK_L = BLOCK_L ,
9994 )
10095
10196 return y .to (orig_dtype )
0 commit comments