33import triton .language as tl
44
55from liger_kernel .ops .backends ._ascend .ub_manager import compute_default_tiling_strategy
6+ from liger_kernel .ops .utils import get_npu_core_count
67
78
89@triton .jit
@@ -15,111 +16,102 @@ def _triton_qwen2vl_mrope_npu(
1516 sin ,
1617 sl ,
1718 bs : tl .constexpr ,
19+ total_rows : tl .constexpr ,
1820 n_qh : tl .constexpr ,
1921 n_kh : tl .constexpr ,
2022 hd : tl .constexpr ,
2123 mrope_section_t : tl .constexpr ,
2224 mrope_section_h : tl .constexpr ,
2325 BLOCK_Q : tl .constexpr ,
2426 BLOCK_K : tl .constexpr ,
27+ NUM_STAGES : tl .constexpr ,
2528 BACKWARD_PASS : tl .constexpr = False ,
2629):
27- pid = tl .program_id (0 ).to (tl .int64 )
30+ program_id = tl .program_id (0 )
31+ num_programs = tl .num_programs (0 )
2832
29- t_end = mrope_section_t
30- h_end = t_end + mrope_section_h
33+ rows_per_program = (total_rows + num_programs - 1 ) // num_programs
34+ start_row = program_id * rows_per_program
35+ actual_rows = tl .minimum (rows_per_program , total_rows - start_row )
3136
32- t_cos = cos + pid * hd
33- h_cos = t_cos + bs * sl * hd
34- w_cos = h_cos + bs * sl * hd
35- t_sin = sin + pid * hd
36- h_sin = t_sin + bs * sl * hd
37- w_sin = h_sin + bs * sl * hd
37+ for row_offset in tl .range (0 , actual_rows , num_stages = NUM_STAGES ):
38+ pid = start_row + row_offset
3839
39- q_base = q_ptr + pid * q_row_stride
40- k_base = k_ptr + pid * k_row_stride
40+ t_end = mrope_section_t
41+ h_end = t_end + mrope_section_h
4142
42- d_idx = tl .arange (0 , hd // 2 )
43- d_mask = d_idx < (hd // 2 )
43+ t_cos = cos + pid * hd
44+ h_cos = t_cos + bs * sl * hd
45+ w_cos = h_cos + bs * sl * hd
46+ t_sin = sin + pid * hd
47+ h_sin = t_sin + bs * sl * hd
48+ w_sin = h_sin + bs * sl * hd
4449
45- pos_mask_t = d_idx < t_end
46- pos_mask_h = ( d_idx >= t_end ) & ( d_idx < h_end )
50+ q_base = q_ptr + pid * q_row_stride
51+ k_base = k_ptr + pid * k_row_stride
4752
48- text_cos_vals = tl .load (t_cos + d_idx , mask = d_mask , other = 0 )
49- text_sin_vals = tl .load (t_sin + d_idx , mask = d_mask , other = 0 )
50- height_cos_vals = tl .load (h_cos + d_idx , mask = d_mask , other = 0 )
51- height_sin_vals = tl .load (h_sin + d_idx , mask = d_mask , other = 0 )
52- width_cos_vals = tl .load (w_cos + d_idx , mask = d_mask , other = 0 )
53- width_sin_vals = tl .load (w_sin + d_idx , mask = d_mask , other = 0 )
53+ d_idx = tl .arange (0 , hd // 2 )
54+ d_mask = d_idx < (hd // 2 )
5455
55- cos_vals = tl . where ( pos_mask_t , text_cos_vals , tl . where ( pos_mask_h , height_cos_vals , width_cos_vals ))
56- sin_vals = tl . where ( pos_mask_t , text_sin_vals , tl . where ( pos_mask_h , height_sin_vals , width_sin_vals ) )
56+ pos_mask_t = d_idx < t_end
57+ pos_mask_h = ( d_idx >= t_end ) & ( d_idx < h_end )
5758
58- for qh_block in range (0 , n_qh , BLOCK_Q ):
59- qh_idx = tl .arange (0 , BLOCK_Q ) + qh_block
60- qh_mask = qh_idx < n_qh
59+ text_cos_vals = tl .load (t_cos + d_idx , mask = d_mask , other = 0 )
60+ text_sin_vals = tl .load (t_sin + d_idx , mask = d_mask , other = 0 )
61+ height_cos_vals = tl .load (h_cos + d_idx , mask = d_mask , other = 0 )
62+ height_sin_vals = tl .load (h_sin + d_idx , mask = d_mask , other = 0 )
63+ width_cos_vals = tl .load (w_cos + d_idx , mask = d_mask , other = 0 )
64+ width_sin_vals = tl .load (w_sin + d_idx , mask = d_mask , other = 0 )
6165
62- block_mask = qh_mask [:, None ] & d_mask [ None , :]
63- offsets = qh_idx [:, None ] * hd + d_idx [ None , :]
66+ cos_vals = tl . where ( pos_mask_t , text_cos_vals , tl . where ( pos_mask_h , height_cos_vals , width_cos_vals ))
67+ sin_vals = tl . where ( pos_mask_t , text_sin_vals , tl . where ( pos_mask_h , height_sin_vals , width_sin_vals ))
6468
65- q_left = tl .load (q_base + offsets , mask = block_mask , other = 0 )
66- q_right = tl .load (q_base + offsets + (hd // 2 ), mask = block_mask , other = 0 )
69+ # Process q heads in chunks to prevent UB overflow
70+ for qh_block in range (0 , n_qh , BLOCK_Q ):
71+ qh_idx = tl .arange (0 , BLOCK_Q ) + qh_block
72+ qh_mask = qh_idx < n_qh
6773
68- if not BACKWARD_PASS :
69- new_left = q_left * cos_vals - q_right * sin_vals
70- new_right = q_right * cos_vals + q_left * sin_vals
71- else :
72- new_left = q_left * cos_vals + q_right * sin_vals
73- new_right = q_right * cos_vals - q_left * sin_vals
74+ block_mask = qh_mask [:, None ] & d_mask [None , :]
75+ offsets = qh_idx [:, None ] * hd + d_idx [None , :]
7476
75- tl .store (q_base + offsets , new_left , mask = block_mask )
76- tl .store (q_base + offsets + (hd // 2 ), new_right , mask = block_mask )
77+ q_left = tl .load (q_base + offsets , mask = block_mask , other = 0 )
78+ q_right = tl .load (q_base + offsets + (hd // 2 ), mask = block_mask , other = 0 )
7779
78- for kh_block in range (0 , n_kh , BLOCK_K ):
79- kh_idx = tl .arange (0 , BLOCK_K ) + kh_block
80- kh_mask = kh_idx < n_kh
80+ if not BACKWARD_PASS :
81+ new_left = q_left * cos_vals - q_right * sin_vals
82+ new_right = q_right * cos_vals + q_left * sin_vals
83+ else :
84+ new_left = q_left * cos_vals + q_right * sin_vals
85+ new_right = q_right * cos_vals - q_left * sin_vals
8186
82- block_mask = kh_mask [:, None ] & d_mask [ None , :]
83- offsets = kh_idx [:, None ] * hd + d_idx [ None , :]
87+ tl . store ( q_base + offsets , new_left , mask = block_mask )
88+ tl . store ( q_base + offsets + ( hd // 2 ), new_right , mask = block_mask )
8489
85- k_left = tl .load (k_base + offsets , mask = block_mask , other = 0 )
86- k_right = tl .load (k_base + offsets + (hd // 2 ), mask = block_mask , other = 0 )
90+ # Process k heads in chunks to prevent UB overflow
91+ for kh_block in range (0 , n_kh , BLOCK_K ):
92+ kh_idx = tl .arange (0 , BLOCK_K ) + kh_block
93+ kh_mask = kh_idx < n_kh
8794
88- if not BACKWARD_PASS :
89- new_left = k_left * cos_vals - k_right * sin_vals
90- new_right = k_right * cos_vals + k_left * sin_vals
91- else :
92- new_left = k_left * cos_vals + k_right * sin_vals
93- new_right = k_right * cos_vals - k_left * sin_vals
95+ block_mask = kh_mask [:, None ] & d_mask [None , :]
96+ offsets = kh_idx [:, None ] * hd + d_idx [None , :]
9497
95- tl .store (k_base + offsets , new_left , mask = block_mask )
96- tl .store (k_base + offsets + (hd // 2 ), new_right , mask = block_mask )
98+ k_left = tl .load (k_base + offsets , mask = block_mask , other = 0 )
99+ k_right = tl .load (k_base + offsets + (hd // 2 ), mask = block_mask , other = 0 )
97100
101+ if not BACKWARD_PASS :
102+ new_left = k_left * cos_vals - k_right * sin_vals
103+ new_right = k_right * cos_vals + k_left * sin_vals
104+ else :
105+ new_left = k_left * cos_vals + k_right * sin_vals
106+ new_right = k_right * cos_vals - k_left * sin_vals
98107
99- def qwen2vl_mrope_forward (q , k , cos , sin , mrope_section ):
100- # transpose it back to the physical shape because Triton looks at the physical storage
101- # note: q and k are incontiguous before the transformation and will become contiguous after transpose
102- q = q .transpose (1 , 2 )
103- k = k .transpose (1 , 2 )
108+ tl .store (k_base + offsets , new_left , mask = block_mask )
109+ tl .store (k_base + offsets + (hd // 2 ), new_right , mask = block_mask )
104110
105- batch_size , seq_len , n_q_head , head_dim = q .shape
106- n_kv_head = k .shape [2 ]
107- pad_hd = triton .next_power_of_2 (head_dim )
108- pad_n_q_head = triton .next_power_of_2 (n_q_head )
109- pad_n_kv_head = triton .next_power_of_2 (n_kv_head )
110111
111- n_row = batch_size * seq_len
112-
113- # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
114- q = q .contiguous ()
115- k = k .contiguous ()
116- cos = cos .contiguous ()
117- sin = sin .contiguous ()
118-
119- # Compute tiling strategy based on UB capacity
120- dtype_size = q .element_size ()
112+ def get_optimal_block_size_mrope (pad_n_q_head , pad_n_kv_head , pad_hd , dtype_size ):
121113 # MROPE forward tiling strategy:
122- # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
114+ # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 6 = 3 * pad_hd elements each
123115 # - In q heads loop (peak memory):
124116 # * q_left: BLOCK_Q * (pad_hd // 2) elements
125117 # * q_right: BLOCK_Q * (pad_hd // 2) elements
@@ -133,9 +125,9 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
133125 # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
134126 # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
135127 # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
136- # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
137- # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
138- # - Simplified: (2 * BLOCK_SIZE + 2 ) * pad_hd * dtype_size * 8 bits
128+ # - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
129+ # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
130+ # - Simplified: (2 * BLOCK_SIZE + 3 ) * pad_hd * dtype_size * 8 bits
139131 # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
140132 # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
141133 # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
@@ -156,9 +148,38 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
156148 BLOCK_K , _ = k_tile_shape
157149 else :
158150 # Fallback to conservative defaults
159- BLOCK_Q = triton .next_power_of_2 (pad_n_q_head )
160- BLOCK_K = triton .next_power_of_2 (pad_n_kv_head )
161- _triton_qwen2vl_mrope_npu [(n_row ,)](
151+ BLOCK_Q = 2048
152+ BLOCK_K = 2048
153+
154+ return BLOCK_Q , BLOCK_K
155+
156+
157+ def qwen2vl_mrope_forward (q , k , cos , sin , mrope_section ):
158+ # transpose it back to the physical shape because Triton looks at the physical storage
159+ q = q .transpose (1 , 2 )
160+ k = k .transpose (1 , 2 )
161+
162+ batch_size , seq_len , n_q_head , head_dim = q .shape
163+ n_kv_head = k .shape [2 ]
164+ pad_hd = triton .next_power_of_2 (head_dim )
165+ pad_n_q_head = triton .next_power_of_2 (n_q_head )
166+ pad_n_kv_head = triton .next_power_of_2 (n_kv_head )
167+
168+ n_row = batch_size * seq_len
169+
170+ # ensure tensors passed into the kernel are contiguous
171+ q = q .contiguous ()
172+ k = k .contiguous ()
173+ cos = cos .contiguous ()
174+ sin = sin .contiguous ()
175+
176+ dtype_size = q .element_size ()
177+ BLOCK_Q , BLOCK_K = get_optimal_block_size_mrope (pad_n_q_head , pad_n_kv_head , pad_hd , dtype_size )
178+
179+ num_cores = get_npu_core_count ()
180+ grid_size = min (num_cores , n_row )
181+
182+ _triton_qwen2vl_mrope_npu [(grid_size ,)](
162183 q ,
163184 q .stride (1 ),
164185 k ,
@@ -167,13 +188,15 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
167188 sin ,
168189 seq_len ,
169190 batch_size ,
191+ n_row ,
170192 n_q_head ,
171193 n_kv_head ,
172194 head_dim ,
173195 mrope_section [0 ],
174196 mrope_section [1 ],
175197 BLOCK_Q ,
176198 BLOCK_K ,
199+ NUM_STAGES = 3 ,
177200 BACKWARD_PASS = False ,
178201 )
179202 return q .transpose (1 , 2 ), k .transpose (1 , 2 ), cos , sin
@@ -195,49 +218,13 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
195218 dq = dq .contiguous ()
196219 dk = dk .contiguous ()
197220
198- # Compute tiling strategy based on UB capacity
199221 dtype_size = dq .element_size ()
200- # MROPE backward tiling strategy:
201- # - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
202- # - In q heads loop (peak memory):
203- # * q_left: BLOCK_Q * (pad_hd // 2) elements
204- # * q_right: BLOCK_Q * (pad_hd // 2) elements
205- # * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
206- # * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
207- # * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
208- # - In k heads loop (peak memory):
209- # * k_left: BLOCK_K * (pad_hd // 2) elements
210- # * k_right: BLOCK_K * (pad_hd // 2) elements
211- # * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
212- # * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
213- # * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
214- # - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
215- # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
216- # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
217- # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
218- # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
219- # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
220- # - tiling_dims: (0, 0) means first dimension of each shape can be tiled
221- # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
222- shapes = ((pad_n_q_head , pad_hd ), (pad_n_kv_head , pad_hd ))
223- tile_shapes = compute_default_tiling_strategy (
224- safety_margin = 0.90 ,
225- dtype_size = dtype_size ,
226- memory_multiplier = 3.0 ,
227- shapes = shapes ,
228- tiling_dims = (0 , 0 ),
229- )
222+ BLOCK_Q , BLOCK_K = get_optimal_block_size_mrope (pad_n_q_head , pad_n_kv_head , pad_hd , dtype_size )
230223
231- if tile_shapes is not None and len (tile_shapes ) == len (shapes ):
232- # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
233- q_tile_shape , k_tile_shape = tile_shapes
234- BLOCK_Q , _ = q_tile_shape
235- BLOCK_K , _ = k_tile_shape
236- else :
237- # Fallback to conservative defaults
238- BLOCK_Q = triton .next_power_of_2 (pad_n_q_head )
239- BLOCK_K = triton .next_power_of_2 (pad_n_kv_head )
240- _triton_qwen2vl_mrope_npu [(n_row ,)](
224+ num_cores = get_npu_core_count ()
225+ grid_size = min (num_cores , n_row )
226+
227+ _triton_qwen2vl_mrope_npu [(grid_size ,)](
241228 dq ,
242229 dq .stride (1 ),
243230 dk ,
@@ -246,13 +233,15 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
246233 sin ,
247234 seq_len ,
248235 batch_size ,
236+ n_row ,
249237 n_q_head ,
250238 n_kv_head ,
251239 head_dim ,
252240 mrope_section [0 ],
253241 mrope_section [1 ],
254242 BLOCK_Q ,
255243 BLOCK_K ,
244+ NUM_STAGES = 3 ,
256245 BACKWARD_PASS = True ,
257246 )
258247 return dq .transpose (1 , 2 ), dk .transpose (1 , 2 )
@@ -272,6 +261,7 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
272261 ctx .mrope_section = mrope_section
273262 return q , k
274263
264+ @staticmethod
275265 def backward (ctx , dq , dk ):
276266 """
277267 dq size: (bsz, n_q_head, seq_len, head_dim)
0 commit comments