@@ -21,32 +21,40 @@ def rotary_kernel(
2121 HALF_D : tl .constexpr ,
2222 BLOCK_D : tl .constexpr ,
2323):
24- pid_h = tl .program_id (0 ).to (tl .int64 )
25- pid_l = tl .program_id (1 ).to (tl .int64 )
24+ # program_id(0) : seq idx)
25+ # program_id(1) : head idx)
26+ # program_id(2) : 块化的 d 维
27+ pid_l = tl .program_id (0 ).to (tl .int64 )
28+ pid_h = tl .program_id (1 ).to (tl .int64 )
2629 pid_blk = tl .program_id (2 ).to (tl .int64 )
2730
31+ # 当前 block 处理的 d 下标
2832 offs_d = tl .arange (0 , BLOCK_D )
2933 d = pid_blk * BLOCK_D + offs_d
3034 mask = d < D
3135
36+ # base 是 (l, h, 0) 的起始地址
3237 base = pid_l * stride_l + pid_h * stride_h
3338
39+ # ---- load x / cos / sin ----
3440 in_ptr = inp_ptr + base + d * stride_d
35- cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
36- sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
3741
38- x = tl .load (in_ptr , mask = mask )
39- cos = tl .load (cos_ptr_ , mask = mask )
40- sin = tl .load (sin_ptr_ , mask = mask )
42+ cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d * stride_cos_d
43+ sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d * stride_sin_d
4144
45+ x = tl .load (in_ptr , mask = mask , other = 0.0 )
46+ cos = tl .load (cos_ptr_ , mask = mask , other = 0.0 )
47+ sin = tl .load (sin_ptr_ , mask = mask , other = 0.0 )
48+
49+ # ---- rotary partner (偶/奇半维互换) ----
4250 partner_d = tl .where (d < HALF_D , d + HALF_D , d - HALF_D )
4351 partner_ptr = inp_ptr + base + partner_d * stride_d
44- partner_val = tl .load (partner_ptr , mask = mask )
52+ partner_val = tl .load (partner_ptr , mask = mask , other = 0.0 )
4553 rotated = tl .where (d < HALF_D , - partner_val , partner_val )
4654
4755 y = x * cos + rotated * sin
4856
49- out_ptr_ = out_ptr + base + d
57+ out_ptr_ = out_ptr + base + d * stride_d
5058 tl .store (out_ptr_ , y , mask = mask )
5159
5260
@@ -57,17 +65,19 @@ def apply_rotary_pos_emb_triton(
5765 assert cos .is_contiguous () and sin .is_contiguous ()
5866 if tensor .ndim != 3 :
5967 raise RuntimeError ("tensor shape should be [L, H, D]" )
68+
6069 orig_dtype = tensor .dtype
61- x = tensor .float ()
70+ x = tensor .float () # [L, H, D]
6271
72+ # cos/sin: [L, D/2] -> [L, D]
6373 cos = cos .repeat (1 , 2 ).view (cos .size (0 ), - 1 ).contiguous ().float ()
6474 sin = sin .repeat (1 , 2 ).view (sin .size (0 ), - 1 ).contiguous ().float ()
6575
6676 L , H , D = x .shape
6777 HALF_D = D // 2
6878 y = torch .empty_like (x )
6979
70- grid = (H , L , triton .cdiv (D , BLOCK_D ))
80+ grid = (L , H , triton .cdiv (D , BLOCK_D ))
7181
7282 rotary_kernel [grid ](
7383 inp_ptr = x ,
0 commit comments