55from liger_kernel .ops .backends ._ascend .ub_manager import compute_default_tiling_strategy
66
77
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
    """
    Canonicalize rotary frequencies to two (seq_len, head_dim_half) tensors.

    Supported inputs:
    - complex freqs: (..., head_dim_half) complex -> split into .real / .imag
    - packed freqs:  (..., 2 * head_dim_half) real -> first half is treated as
      the real part, second half as the imaginary part

    All leading dimensions are flattened into a single row dimension. A single
    row is broadcast to ``seq_len`` rows; extra rows beyond ``seq_len`` are
    dropped.

    Args:
        freqs_cis: Complex or packed-real frequency tensor.
        seq_len: Number of rows required in the result.
        head_dim_half: Expected size of the last dimension of each result.

    Returns:
        Tuple ``(freqs_real, freqs_imag)``, each of shape
        ``(seq_len, head_dim_half)``. The results may be views of the input.

    Raises:
        ValueError: If the last dimension does not match the expected layout,
            or if there are more than one but fewer than ``seq_len`` rows.
    """
    if freqs_cis.is_complex():
        freqs_real = freqs_cis.real
        freqs_imag = freqs_cis.imag
    elif freqs_cis.shape[-1] == 2 * head_dim_half:
        freqs_real = freqs_cis[..., :head_dim_half]
        freqs_imag = freqs_cis[..., head_dim_half:]
    else:
        raise ValueError(
            f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
            f"expected last dim = {2 * head_dim_half}"
        )

    # The packed branch guarantees this by construction; the check is what
    # validates the complex branch.
    if freqs_real.shape[-1] != head_dim_half:
        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")

    # Flatten leading dims -> (N, head_dim_half)
    # NOTE(review): a leading dim > 1 (e.g. a batch dim) is merged into the row
    # axis, so the slice below keeps only the first seq_len flattened rows --
    # confirm callers never pass per-batch frequencies.
    freqs_real = freqs_real.reshape(-1, head_dim_half)
    freqs_imag = freqs_imag.reshape(-1, head_dim_half)

    # Broadcast/slice to (seq_len, head_dim_half)
    n_rows = freqs_real.shape[0]
    if n_rows < seq_len:
        if n_rows != 1:
            raise ValueError(f"Insufficient rows in freqs: {n_rows} < seq_len={seq_len}")
        freqs_real = freqs_real.expand(seq_len, -1)
        freqs_imag = freqs_imag.expand(seq_len, -1)
    elif n_rows > seq_len:
        freqs_real = freqs_real[:seq_len]
        freqs_imag = freqs_imag[:seq_len]

    return freqs_real, freqs_imag
48-
49-
50- def _cast_and_contiguous (q , k , freqs_real , freqs_imag ):
8+ def _cast_and_contiguous (q , k , freqs_complex ):
519 # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
5210 compute_dtype = torch .float32 if q .dtype == torch .float32 else q .dtype
5311
@@ -56,17 +14,15 @@ def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
5614
5715 q = q .to (compute_dtype ).contiguous ()
5816 k = k .to (compute_dtype ).contiguous ()
59- freqs_real = freqs_real .to (compute_dtype ).contiguous ()
60- freqs_imag = freqs_imag .to (compute_dtype ).contiguous ()
61- return q , k , freqs_real , freqs_imag , compute_dtype
17+ freqs_complex = freqs_complex .contiguous ()
18+ return q , k , freqs_complex , compute_dtype
6219
6320
6421@triton .jit
6522def _triton_llama4_rope_npu (
6623 q_ptr ,
6724 k_ptr ,
68- freqs_real_ptr ,
69- freqs_imag_ptr ,
25+ freqs_complex_ptr ,
7026 q_row_stride ,
7127 k_row_stride ,
7228 q_head_stride ,
@@ -84,8 +40,7 @@ def _triton_llama4_rope_npu(
8440 """
8541 Llama4 RoPE on Ascend NPU for interleaved complex layout:
8642 - q/k shape: (bs, sl, n_heads, hd)
87- - last dim layout: [real0, imag0, real1, imag1, ...]
88- - freqs_real/imag: (sl, hd//2)
43+ - freqs_complex_ptr: (sl, hd//2, 2)
8944 """
9045 pid = tl .program_id (0 ).to (tl .int64 )
9146 batch_idx = pid // sl
@@ -101,11 +56,14 @@ def _triton_llama4_rope_npu(
10156 hd_idx = tl .arange (0 , hd )
10257 hd_mask = hd_idx < (hd )
10358
104- freq_idx = tl .arange (0 , hd // 2 )
105- freq_mask = freq_idx < (hd // 2 )
59+ freq_idx = tl .arange (0 , hd )
60+ freq_mask = freq_idx < (hd )
10661
107- freqs_real = tl .load (freqs_real_ptr + freq_base + freq_idx , mask = freq_mask , other = 0.0 )
108- freqs_imag = tl .load (freqs_imag_ptr + freq_base + freq_idx , mask = freq_mask , other = 0.0 ) * imag_sign
62+ freqs_complex = tl .load (freqs_complex_ptr + freq_base + freq_idx , mask = freq_mask , other = 0.0 )
63+
64+ freqs_complex = freqs_complex .reshape (hd // 2 , 2 , can_reorder = True )
65+ freqs_real , freqs_imag = tl .split (freqs_complex )
66+ freqs_imag = freqs_imag * imag_sign
10967
11068 # Q heads (chunked for UB)
11169 for qh_block in range (0 , n_qh , BLOCK_Q ):
@@ -166,10 +124,14 @@ def llama4_rope_forward(q, k, freqs_cis):
166124 _ , _ , n_kh , _ = k .shape
167125 if hd % 2 != 0 :
168126 raise ValueError (f"head_dim must be even for interleaved complex layout, got { hd } " )
169- hd_half = hd // 2
170127
171- freqs_real , freqs_imag = _prepare_freqs (freqs_cis , sl , hd_half )
172- q , k , freqs_real , freqs_imag , compute_dtype = _cast_and_contiguous (q , k , freqs_real , freqs_imag )
128+ if freqs_cis .is_complex ():
129+ freqs_cis = freqs_cis .reshape (- 1 , freqs_cis .shape [- 1 ])
130+ if freqs_cis .shape [0 ] > sl :
131+ freqs_cis = freqs_cis [:sl ]
132+ freqs_cis = torch .view_as_real (freqs_cis )
133+
134+ q , k , freqs_cis , compute_dtype = _cast_and_contiguous (q , k , freqs_cis )
173135
174136 # UB tiling strategy: tile heads dimension only
175137 dtype_size = q .element_size ()
@@ -195,13 +157,12 @@ def llama4_rope_forward(q, k, freqs_cis):
195157 _triton_llama4_rope_npu [(n_row ,)](
196158 q ,
197159 k ,
198- freqs_real ,
199- freqs_imag ,
160+ freqs_cis ,
200161 q .stride (1 ),
201162 k .stride (1 ),
202163 q .stride (2 ),
203164 k .stride (2 ),
204- freqs_real .stride (0 ),
165+ freqs_cis .stride (0 ),
205166 sl ,
206167 bs ,
207168 n_qh ,
@@ -231,10 +192,14 @@ def llama4_rope_backward(dq, dk, freqs_cis):
231192 _ , _ , n_kh , _ = dk .shape
232193 if hd % 2 != 0 :
233194 raise ValueError (f"head_dim must be even for interleaved complex layout, got { hd } " )
234- hd_half = hd // 2
235195
236- freqs_real , freqs_imag = _prepare_freqs (freqs_cis , sl , hd_half )
237- dq , dk , freqs_real , freqs_imag , compute_dtype = _cast_and_contiguous (dq , dk , freqs_real , freqs_imag )
196+ if freqs_cis .is_complex ():
197+ freqs_cis = freqs_cis .reshape (- 1 , freqs_cis .shape [- 1 ])
198+ if freqs_cis .shape [0 ] > sl :
199+ freqs_cis = freqs_cis [:sl ]
200+ freqs_cis = torch .view_as_real (freqs_cis )
201+
202+ dq , dk , freqs_cis , compute_dtype = _cast_and_contiguous (dq , dk , freqs_cis )
238203
239204 # UB tiling strategy: tile heads dimension only
240205 dtype_size = dq .element_size ()
@@ -260,13 +225,12 @@ def llama4_rope_backward(dq, dk, freqs_cis):
260225 _triton_llama4_rope_npu [(n_row ,)](
261226 dq ,
262227 dk ,
263- freqs_real ,
264- freqs_imag ,
228+ freqs_cis ,
265229 dq .stride (1 ),
266230 dk .stride (1 ),
267231 dq .stride (2 ),
268232 dk .stride (2 ),
269- freqs_real .stride (0 ),
233+ freqs_cis .stride (0 ),
270234 sl ,
271235 bs ,
272236 n_qh ,
0 commit comments