55
66
@triton.jit
def mrope_kernel(
    Q_ptr,
    K_ptr,
    COS_ptr,
    SIN_ptr,
    AXIS_ptr,
    QO_ptr,
    KO_ptr,
    B: tl.int32,
    H_q: tl.int32,
    H_k: tl.int32,
    L: tl.int32,
    D: tl.int32,
    HALF: tl.constexpr,
    q_sb: tl.int32,
    q_sh: tl.int32,
    q_sl: tl.int32,
    q_sd: tl.int32,
    k_sb: tl.int32,
    k_sh: tl.int32,
    k_sl: tl.int32,
    k_sd: tl.int32,
    qo_sb: tl.int32,
    qo_sh: tl.int32,
    qo_sl: tl.int32,
    qo_sd: tl.int32,
    ko_sb: tl.int32,
    ko_sh: tl.int32,
    ko_sl: tl.int32,
    ko_sd: tl.int32,
    BLOCK_D: tl.constexpr,
):
    # Applies out = x * cos + rotate_half(x) * sin to one row of Q or K per
    # program instance.
    #
    # Launch layout, as read from the program ids below:
    #   program_id(0) enumerates (batch, head) pairs over the H_q + H_k heads
    #                 of Q and K concatenated; heads [0, H_q) read Q, the
    #                 remaining H_k heads read K.
    #   program_id(1) is the index multiplied by the *_sl strides.
    #
    # The *_sb / *_sh / *_sl / *_sd arguments are element strides applied to
    # the batch, head, program_id(1) and feature indices respectively, for the
    # Q input (q_*), K input (k_*), Q output (qo_*) and K output (ko_*).
    # AXIS_ptr is read at offsets [0, D) and yields the table index used for
    # every feature; COS_ptr / SIN_ptr are indexed as flat [axis, l, d] buffers
    # with no batch term.
    # NOTE(review): B is accepted for interface compatibility but not read.
    total_h = H_q + H_k
    pid_bh = tl.program_id(0)
    pid_l = tl.program_id(1)

    # Split the fused (batch, head) program id.
    b = pid_bh // total_h
    h_local = pid_bh - b * total_h

    # Heads below H_q belong to Q; the rest are K heads re-based to zero.
    is_q = h_local < H_q
    h_index = tl.where(is_q, h_local, h_local - H_q)

    # Pick the input/output tensor and the matching strides for this head.
    base_ptr = tl.where(is_q, Q_ptr, K_ptr)
    out_ptr = tl.where(is_q, QO_ptr, KO_ptr)

    sb = tl.where(is_q, q_sb, k_sb)
    sh = tl.where(is_q, q_sh, k_sh)
    sl = tl.where(is_q, q_sl, k_sl)
    sd = tl.where(is_q, q_sd, k_sd)

    osb = tl.where(is_q, qo_sb, ko_sb)
    osh = tl.where(is_q, qo_sh, ko_sh)
    osl = tl.where(is_q, qo_sl, ko_sl)
    osd = tl.where(is_q, qo_sd, ko_sd)

    # Promote the row indices to 64-bit before multiplying by strides so the
    # row base offset cannot wrap around for tensors with more than 2**31
    # elements.
    b64 = b.to(tl.int64)
    h64 = h_index.to(tl.int64)
    l64 = pid_l.to(tl.int64)
    base = b64 * sb + h64 * sh + l64 * sl
    out_base = b64 * osb + h64 * osh + l64 * osl

    # Feature lanes; lanes at or beyond D are masked out of every load/store.
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    first_half = offs < HALF

    vals = tl.load(base_ptr + base + offs * sd, mask=mask, other=0.0)

    # rotate_half: lane i pairs with lane i + HALF (first half, negated) or
    # lane i - HALF (second half, sign kept).
    rot_offs = tl.where(first_half, (offs + HALF) * sd, (offs - HALF) * sd)
    rot_vals = tl.load(base_ptr + base + rot_offs, mask=mask, other=0.0)
    rot_vals = tl.where(first_half, -rot_vals, rot_vals)

    # Per-feature axis id (0, 1 or 2) selects which cos/sin table to read.
    axis_id = tl.load(AXIS_ptr + offs, mask=mask, other=0)
    cos_idx = axis_id * (L * D) + pid_l * D + offs
    c = tl.load(COS_ptr + cos_idx, mask=mask, other=0.0)
    s = tl.load(SIN_ptr + cos_idx, mask=mask, other=0.0)

    out = vals * c + rot_vals * s

    tl.store(out_ptr + out_base + offs * osd, out, mask=mask)
6686
6787
68- def mrope_triton (q : torch .Tensor , k : torch .Tensor , cos : torch .Tensor , sin : torch .Tensor , mrope_section ):
88+ def mrope_triton (q : torch .Tensor , k : torch .Tensor , cos : torch .Tensor , sin : torch .Tensor , axis_map : torch .Tensor ):
89+
6990 B , H_q , L , D = q .shape
7091 H_k = k .shape [1 ]
92+ HALF = D // 2
7193
72- # build axis_map 0/1/2 label per feature dim
73- axis_map = []
74- for i , n in enumerate (mrope_section * 2 ):
75- axis_map += [i % 3 ] * n
76- axis_map = torch .tensor (axis_map , dtype = torch .int32 , device = q .device )
77-
78- cos_flat = cos .transpose (0 , 1 ).expand (B , 3 , L , D ).contiguous ().reshape (- 1 )
79- sin_flat = sin .transpose (0 , 1 ).expand (B , 3 , L , D ).contiguous ().reshape (- 1 )
94+ q_sb , q_sh , q_sl , q_sd = map (int , q .stride ())
95+ k_sb , k_sh , k_sl , k_sd = map (int , k .stride ())
8096
8197 q_out = torch .empty_like (q )
8298 k_out = torch .empty_like (k )
99+ qo_sb , qo_sh , qo_sl , qo_sd = map (int , q_out .stride ())
100+ ko_sb , ko_sh , ko_sl , ko_sd = map (int , k_out .stride ())
101+
102+ cos_flat = cos .transpose (0 , 1 ).contiguous ().reshape (- 1 )
103+ sin_flat = sin .transpose (0 , 1 ).contiguous ().reshape (- 1 )
83104
84105 grid = (B * (H_q + H_k ), L )
85- mrope_kernel_combined [grid ](
106+
107+ mrope_kernel [grid ](
86108 q ,
87109 k ,
88110 cos_flat ,
@@ -95,8 +117,26 @@ def mrope_triton(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch
95117 H_k ,
96118 L ,
97119 D ,
98- D // 2 ,
120+ HALF ,
121+ q_sb ,
122+ q_sh ,
123+ q_sl ,
124+ q_sd ,
125+ k_sb ,
126+ k_sh ,
127+ k_sl ,
128+ k_sd ,
129+ qo_sb ,
130+ qo_sh ,
131+ qo_sl ,
132+ qo_sd ,
133+ ko_sb ,
134+ ko_sh ,
135+ ko_sl ,
136+ ko_sd ,
99137 BLOCK_D = 128 ,
138+ num_warps = 4 ,
139+ num_stages = 3 ,
100140 )
101141 return q_out , k_out
102142
@@ -125,21 +165,25 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
125165 k_out = k * cos_embed + rotate_half (k ) * sin_embed
126166 return q_out , k_out
127167
128- B , H_q , H_k , L , D = 1 , 28 , 4 , 16384 , 128
168+ B , H_q , H_k , L , D = 3 , 28 , 4 , 16384 , 128
129169 mrope_section = [16 , 24 , 24 ]
130170 torch .manual_seed (0 )
131171 device = "cuda"
132172
133- q = torch .rand (B , H_q , L , D , dtype = torch .float32 , device = device )
134- k = torch .rand (B , H_k , L , D , dtype = torch .float32 , device = device )
173+ q = torch .rand (B , H_q , L , D , dtype = torch .float32 , device = device ). transpose ( 1 , 2 ). contiguous (). transpose ( 1 , 2 )
174+ k = torch .rand (B , H_k , L , D , dtype = torch .float32 , device = device ). transpose ( 1 , 2 ). contiguous (). transpose ( 1 , 2 )
135175 cos = torch .rand (3 , 1 , L , D , dtype = torch .float32 , device = device )
136176 sin = torch .rand (3 , 1 , L , D , dtype = torch .float32 , device = device )
137177
138178 # Accuracy comparison
179+ axis_map = []
180+ for i , n in enumerate (mrope_section * 2 ):
181+ axis_map += [i % 3 ] * n
182+ axis_map = torch .tensor (axis_map , dtype = torch .int32 , device = "cuda" )
139183 ref_q , ref_k = apply_multimodal_rotary_pos_emb (q , k , cos , sin , mrope_section , unsqueeze_dim = 1 )
140184
141185 torch .cuda .synchronize ()
142- out_q , out_k = mrope_triton (q , k , cos , sin , mrope_section )
186+ out_q , out_k = mrope_triton (q , k , cos , sin , axis_map )
143187 torch .cuda .synchronize ()
144188
145189 err_q = (out_q - ref_q ).abs ().max ().item ()
@@ -162,7 +206,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
162206
163207 e0 .record ()
164208 for _ in range (n_iter ):
165- _ = mrope_triton (q , k , cos , sin , mrope_section )
209+ _ = mrope_triton (q , k , cos , sin , axis_map )
166210 e1 .record ()
167211 torch .cuda .synchronize ()
168212 t_tri = e0 .elapsed_time (e1 ) / n_iter
0 commit comments