Skip to content

Commit 0ea0b8f

Browse files
authored
[NPU] Add Llama4_rope support on NPU (#1035)
## Summary This PR implements a fully executable Llama4 RoPE operator for Ascend NPU. 1. Prevents UB overflow issues specific to NPU execution 2. Implements interleaved complex layout compatible with NPU kernels <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <img width="2153" height="477" alt="image" src="https://github.com/user-attachments/assets/e8347175-8b42-41b2-a41b-542f2aaa71fd" /> Test done with `python -m pytest ./test/transformers/test_llama4_rope.py -v` Verified on Ascend NPU 910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 1f51687 commit 0ea0b8f

File tree

3 files changed

+305
-1
lines changed

3 files changed

+305
-1
lines changed

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
2121
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
2222
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
23+
from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
24+
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
25+
from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
2326
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
2427
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
2528
from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
@@ -52,4 +55,7 @@
5255
"LigerTVDLossFunction",
5356
"tv_distance_forward_triton",
5457
"tvd_backward_triton",
58+
"LigerLlama4RopeFunction",
59+
"llama4_rope_forward",
60+
"llama4_rope_backward",
5561
]
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
6+
7+
8+
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
9+
"""
10+
Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.
11+
12+
Supports:
13+
- complex freqs: (..., head_dim_half) complex -> real/imag
14+
- packed freqs: (..., 2*head_dim_half) real -> split into real/imag
15+
"""
16+
if freqs_cis.is_complex():
17+
freqs_real = freqs_cis.real
18+
freqs_imag = freqs_cis.imag
19+
else:
20+
if freqs_cis.shape[-1] == 2 * head_dim_half:
21+
freqs_real = freqs_cis[..., :head_dim_half]
22+
freqs_imag = freqs_cis[..., head_dim_half:]
23+
else:
24+
raise ValueError(
25+
f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
26+
f"expected last dim = {2 * head_dim_half}"
27+
)
28+
29+
if freqs_real.shape[-1] != head_dim_half:
30+
raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
31+
32+
# Flatten leading dims -> (N, head_dim_half)
33+
freqs_real = freqs_real.reshape(-1, head_dim_half)
34+
freqs_imag = freqs_imag.reshape(-1, head_dim_half)
35+
36+
# Broadcast/slice to (seq_len, head_dim_half)
37+
if freqs_real.shape[0] < seq_len:
38+
if freqs_real.shape[0] == 1:
39+
freqs_real = freqs_real.expand(seq_len, -1)
40+
freqs_imag = freqs_imag.expand(seq_len, -1)
41+
else:
42+
raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
43+
elif freqs_real.shape[0] > seq_len:
44+
freqs_real = freqs_real[:seq_len]
45+
freqs_imag = freqs_imag[:seq_len]
46+
47+
return freqs_real, freqs_imag
48+
49+
50+
def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
51+
# Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
52+
compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
53+
54+
if k.dtype != q.dtype:
55+
k = k.to(q.dtype)
56+
57+
q = q.to(compute_dtype).contiguous()
58+
k = k.to(compute_dtype).contiguous()
59+
freqs_real = freqs_real.to(compute_dtype).contiguous()
60+
freqs_imag = freqs_imag.to(compute_dtype).contiguous()
61+
return q, k, freqs_real, freqs_imag, compute_dtype
62+
63+
64+
@triton.jit
def _triton_llama4_rope_npu(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    imag_sign: tl.constexpr,
):
    """
    Llama4 RoPE on Ascend NPU for interleaved complex layout:
    - q/k shape: (bs, sl, n_heads, hd)
    - last dim layout: [real0, imag0, real1, imag1, ...]
    - freqs_real/imag: (sl, hd//2)

    One program handles one (batch, seq) row and rotates all q heads, then
    all k heads, in chunks of BLOCK_Q / BLOCK_K heads to fit the NPU Unified
    Buffer (UB). The rotation is applied in place (loads and stores through
    the same pointers).

    imag_sign flips the rotation direction: the forward launcher passes +1.0
    (rotate by +theta) and the backward launcher passes -1.0 (rotate by the
    conjugate, -theta).
    """
    # int64 program id guards against 32-bit overflow when bs*sl is large.
    pid = tl.program_id(0).to(tl.int64)
    batch_idx = pid // sl
    seq_idx = pid % sl

    # Grid is exactly bs*sl, but keep the guard for over-provisioned launches.
    if batch_idx >= bs:
        return

    # pid already enumerates (batch, seq) rows, so row stride indexes directly.
    q_base = q_ptr + pid * q_row_stride
    k_base = k_ptr + pid * k_row_stride

    freq_base = seq_idx * freqs_row_stride
    hd_idx = tl.arange(0, hd)
    hd_mask = hd_idx < (hd)

    freq_idx = tl.arange(0, hd // 2)
    freq_mask = freq_idx < (hd // 2)

    # Frequencies for this sequence position; imag_sign applied once here.
    freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
    freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign

    # Q heads (chunked for UB)
    for qh_block in range(0, n_qh, BLOCK_Q):
        qh_idx = tl.arange(0, BLOCK_Q) + qh_block
        qh_mask = qh_idx < n_qh
        block_mask = qh_mask[:, None] & hd_mask[None, :]

        head_ptr = q_base + qh_idx[:, None] * q_head_stride

        q_pair = tl.load(
            head_ptr + hd_idx[None, :],
            mask=block_mask,
            other=0.0,
        )
        # View the interleaved last dim as (pairs, 2) and split into
        # real/imag components of each complex pair.
        q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True)
        q_real, q_imag = tl.split(q_pair)

        # Complex multiply (q_real + i*q_imag) * (f_real + i*f_imag):
        #   new_real = q_real*f_real - q_imag*f_imag
        #   new_imag = q_real*f_imag + q_imag*f_real
        new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
        new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
        # Re-interleave back to [real0, imag0, real1, imag1, ...].
        new_q_pair = tl.interleave(new_real, new_imag)

        tl.store(head_ptr + hd_idx[None, :], new_q_pair, mask=block_mask)

    # K heads (chunked for UB)
    for kh_block in range(0, n_kh, BLOCK_K):
        kh_idx = tl.arange(0, BLOCK_K) + kh_block
        kh_mask = kh_idx < n_kh
        block_mask = kh_mask[:, None] & hd_mask[None, :]

        head_ptr = k_base + kh_idx[:, None] * k_head_stride

        k_pair = tl.load(
            head_ptr + hd_idx[None, :],
            mask=block_mask,
            other=0.0,
        )

        k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True)
        k_real, k_imag = tl.split(k_pair)

        new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
        new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
        new_k_pair = tl.interleave(new_real, new_imag)

        tl.store(head_ptr + hd_idx[None, :], new_k_pair, mask=block_mask)
154+
155+
156+
def _llama4_rope_launch(q, k, freqs_cis, imag_sign):
    """
    Shared driver for the Llama4 RoPE forward and backward passes on Ascend NPU.

    Canonicalizes freqs, aligns dtypes, computes the UB head-tiling, launches
    the kernel (one program per (batch, seq) row), and casts the result back
    to the original dtype.

    Args:
        q: (bs, sl, n_q_heads, hd) tensor with interleaved complex last dim.
        k: (bs, sl, n_k_heads, hd) tensor with interleaved complex last dim.
        freqs_cis: complex (..., hd//2) OR packed real (..., 2*(hd//2)).
        imag_sign: +1.0 rotates by +theta (forward); -1.0 rotates by the
            conjugate (backward gradient).

    Returns:
        (q, k) rotated, in the original dtype.

    Raises:
        ValueError: if hd is odd (interleaved complex needs an even head dim).

    NOTE(review): when q/k are already contiguous in the compute dtype,
    ``.to(...).contiguous()`` returns the same storage and the kernel rotates
    the caller's tensors in place — confirm callers expect this.
    """
    original_dtype = q.dtype

    bs, sl, n_qh, hd = q.shape
    _, _, n_kh, _ = k.shape
    if hd % 2 != 0:
        raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
    hd_half = hd // 2

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
    q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # UB tiling strategy: tile heads dimension only (dim 0 of each shape).
    dtype_size = q.element_size()
    shapes = ((n_qh, hd), (n_kh, hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=12.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        (BLOCK_Q, _), (BLOCK_K, _) = tile_shapes
    else:
        # Fallback: a single tile covering all heads.
        BLOCK_Q = triton.next_power_of_2(n_qh)
        BLOCK_K = triton.next_power_of_2(n_kh)

    n_row = bs * sl  # one program per (batch, seq) row

    _triton_llama4_rope_npu[(n_row,)](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        sl,
        bs,
        n_qh,
        n_kh,
        hd,
        BLOCK_Q,
        BLOCK_K,
        imag_sign=imag_sign,
    )

    if compute_dtype != original_dtype:
        q = q.to(original_dtype)
        k = k.to(original_dtype)
    return q, k


def llama4_rope_forward(q, k, freqs_cis):
    """
    Ascend NPU implementation of Llama4 RoPE (forward pass).

    q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
    freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).

    Returns (q, k) rotated by +theta, in the input dtype.
    """
    return _llama4_rope_launch(q, k, freqs_cis, imag_sign=1.0)


def llama4_rope_backward(dq, dk, freqs_cis):
    """
    Ascend NPU implementation of the Llama4 RoPE backward pass.

    dq/dk: (bs, sl, n_heads, hd) upstream gradients with interleaved complex
    last-dim layout.
    freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).

    The gradient of a rotation by +theta is a rotation by -theta, so this is
    the forward launch with the imaginary component negated (conjugate).
    Returns (dq, dk) gradients w.r.t. the forward inputs.
    """
    return _llama4_rope_launch(dq, dk, freqs_cis, imag_sign=-1.0)
284+
285+
286+
class LigerLlama4RopeFunction(torch.autograd.Function):
    """Autograd wrapper around the Ascend NPU Llama4 RoPE kernels."""

    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        # BLOCK_SIZE is ignored for Ascend (we auto-tile heads by UB), kept for API compatibility
        rotated_q, rotated_k = llama4_rope_forward(q, k, freqs_cis)
        # Detach freqs before stashing so no graph is kept alive through ctx.
        saved = freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis
        ctx.save_for_backward(saved)
        return rotated_q, rotated_k

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        grad_q, grad_k = llama4_rope_backward(dq, dk, freqs_cis)
        # No gradients for freqs_cis or BLOCK_SIZE.
        return grad_q, grad_k, None, None

test/transformers/test_llama4_rope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from test.utils import supports_bfloat16
55

6-
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
6+
from liger_kernel.ops import LigerLlama4RopeFunction
77
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
88
from liger_kernel.utils import infer_device
99

0 commit comments

Comments
 (0)