@@ -21,40 +21,33 @@ def rotary_kernel(
2121 HALF_D : tl .constexpr ,
2222 BLOCK_D : tl .constexpr ,
2323):
24- # program_id(0) : seq idx)
25- # program_id(1) : head idx)
26- # program_id(2) : 块化的 d 维
2724 pid_l = tl .program_id (0 ).to (tl .int64 )
2825 pid_h = tl .program_id (1 ).to (tl .int64 )
2926 pid_blk = tl .program_id (2 ).to (tl .int64 )
3027
31- # 当前 block 处理的 d 下标
3228 offs_d = tl .arange (0 , BLOCK_D )
3329 d = pid_blk * BLOCK_D + offs_d
3430 mask = d < D
3531
36- # base 是 (l, h, 0) 的起始地址
3732 base = pid_l * stride_l + pid_h * stride_h
3833
39- # ---- load x / cos / sin ----
4034 in_ptr = inp_ptr + base + d * stride_d
4135
42- cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d * stride_cos_d
43- sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d * stride_sin_d
36+ cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
37+ sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
4438
45- x = tl .load (in_ptr , mask = mask , other = 0.0 )
46- cos = tl .load (cos_ptr_ , mask = mask , other = 0.0 )
47- sin = tl .load (sin_ptr_ , mask = mask , other = 0.0 )
39+ x = tl .load (in_ptr , mask = mask )
40+ cos = tl .load (cos_ptr_ , mask = mask )
41+ sin = tl .load (sin_ptr_ , mask = mask )
4842
49- # ---- rotary partner (偶/奇半维互换) ----
5043 partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
5144 partner_ptr = inp_ptr + base + partner_d * stride_d
52- partner_val = tl .load (partner_ptr , mask = mask , other = 0.0 )
45+ partner_val = tl .load (partner_ptr , mask = mask )
5346 rotated = tl .where (d < HALF_D , - partner_val , partner_val )
5447
5548 y = x * cos + rotated * sin
5649
57- out_ptr_ = out_ptr + base + d * stride_d
50+ out_ptr_ = out_ptr + base + d
5851 tl .store (out_ptr_ , y , mask = mask )
5952
6053
@@ -67,9 +60,8 @@ def apply_rotary_pos_emb_triton(
6760 raise RuntimeError ("tensor shape should be [L, H, D]" )
6861
6962 orig_dtype = tensor .dtype
70- x = tensor .float () # [L, H, D]
63+ x = tensor .float ()
7164
72- # cos/sin: [L, D/2] -> [L, D]
7365 cos = cos .repeat (1 , 2 ).view (cos .size (0 ), - 1 ).contiguous ().float ()
7466 sin = sin .repeat (1 , 2 ).view (sin .size (0 ), - 1 ).contiguous ().float ()
7567
0 commit comments