
Commit ad6f0a7

[NPU] Frequencies fusion for Llama4_rope on NPU (#1053)
## Summary

This PR is a descendant of #1035. It removes `_prepare_freqs` and directly uses a single `freqs_complex_ptr` for llama4_rope frequencies inside the Triton kernel. Avoiding the extra preprocessing and the separate real/imag loads simplifies the code path and improves performance; benchmarks show better results than the previous implementation.

## Testing Done

Tested with `python -m pytest ./test/transformers/test_llama4_rope.py -v`
Verified on Atlas 800I A2

- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 60772c9 commit ad6f0a7
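
Note (not part of this commit): a minimal sketch of the frequency layout the kernel now consumes. When `freqs_cis` is complex with shape `(seq_len, head_dim // 2)`, `torch.view_as_real` exposes it as a real tensor of shape `(seq_len, head_dim // 2, 2)` whose last axis holds interleaved (real, imag) pairs; that single buffer is what the kernel receives as `freqs_complex_ptr`. The shapes below are assumptions for illustration.

```python
import torch

# Hypothetical shapes, for illustration only.
seq_len, head_dim = 8, 64

# Complex RoPE frequencies of shape (seq_len, head_dim // 2).
freqs_cis = torch.polar(
    torch.ones(seq_len, head_dim // 2),
    torch.randn(seq_len, head_dim // 2),
)

# What the new code path does before launching the kernel.
freqs = freqs_cis.reshape(-1, freqs_cis.shape[-1])[:seq_len]
freqs = torch.view_as_real(freqs)  # real-valued, shape (seq_len, head_dim // 2, 2)

# The last axis interleaves (real, imag) pairs.
assert torch.equal(freqs[..., 0], freqs_cis.real)
assert torch.equal(freqs[..., 1], freqs_cis.imag)
```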

File tree

2 files changed (+51, -132 lines)

src/liger_kernel/ops/backends/_ascend/ops/llama4_rope.py

Lines changed: 30 additions & 66 deletions
@@ -5,49 +5,7 @@
 from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
 
 
-def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
-    """
-    Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.
-
-    Supports:
-    - complex freqs: (..., head_dim_half) complex -> real/imag
-    - packed freqs: (..., 2*head_dim_half) real -> split into real/imag
-    """
-    if freqs_cis.is_complex():
-        freqs_real = freqs_cis.real
-        freqs_imag = freqs_cis.imag
-    else:
-        if freqs_cis.shape[-1] == 2 * head_dim_half:
-            freqs_real = freqs_cis[..., :head_dim_half]
-            freqs_imag = freqs_cis[..., head_dim_half:]
-        else:
-            raise ValueError(
-                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
-                f"expected last dim = {2 * head_dim_half}"
-            )
-
-    if freqs_real.shape[-1] != head_dim_half:
-        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
-
-    # Flatten leading dims -> (N, head_dim_half)
-    freqs_real = freqs_real.reshape(-1, head_dim_half)
-    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
-
-    # Broadcast/slice to (seq_len, head_dim_half)
-    if freqs_real.shape[0] < seq_len:
-        if freqs_real.shape[0] == 1:
-            freqs_real = freqs_real.expand(seq_len, -1)
-            freqs_imag = freqs_imag.expand(seq_len, -1)
-        else:
-            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
-    elif freqs_real.shape[0] > seq_len:
-        freqs_real = freqs_real[:seq_len]
-        freqs_imag = freqs_imag[:seq_len]
-
-    return freqs_real, freqs_imag
-
-
-def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+def _cast_and_contiguous(q, k, freqs_complex):
     # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
     compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
 
@@ -56,17 +14,15 @@ def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
 
     q = q.to(compute_dtype).contiguous()
     k = k.to(compute_dtype).contiguous()
-    freqs_real = freqs_real.to(compute_dtype).contiguous()
-    freqs_imag = freqs_imag.to(compute_dtype).contiguous()
-    return q, k, freqs_real, freqs_imag, compute_dtype
+    freqs_complex = freqs_complex.contiguous()
+    return q, k, freqs_complex, compute_dtype
 
 
 @triton.jit
 def _triton_llama4_rope_npu(
     q_ptr,
     k_ptr,
-    freqs_real_ptr,
-    freqs_imag_ptr,
+    freqs_complex_ptr,
     q_row_stride,
     k_row_stride,
     q_head_stride,
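
For reference, the reworked helper as a standalone, runnable sketch: `_cast_and_contiguous` now only enforces contiguity on the frequency tensor instead of casting it, while q/k still share a compute dtype. The function body mirrors the NPU-side diff above; the tensors and dtypes in the small demo are assumptions.

```python
import torch


def _cast_and_contiguous(q, k, freqs_complex):
    # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype

    if k.dtype != q.dtype:
        k = k.to(q.dtype)

    q = q.to(compute_dtype).contiguous()
    k = k.to(compute_dtype).contiguous()
    freqs_complex = freqs_complex.contiguous()
    return q, k, freqs_complex, compute_dtype


# Assumed inputs: half-precision q/k keep their dtype; freqs are only made contiguous.
q = torch.randn(1, 4, 8, 64, dtype=torch.bfloat16)
k = torch.randn(1, 4, 2, 64, dtype=torch.float16)
freqs = torch.randn(4, 32, 2)  # (sl, hd // 2, 2) real/imag pairs

q, k, freqs, compute_dtype = _cast_and_contiguous(q, k, freqs)
assert compute_dtype == torch.bfloat16 and k.dtype == torch.bfloat16
assert freqs.dtype == torch.float32  # untouched except for contiguity
```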
@@ -84,8 +40,7 @@ def _triton_llama4_rope_npu(
     """
     Llama4 RoPE on Ascend NPU for interleaved complex layout:
     - q/k shape: (bs, sl, n_heads, hd)
-    - last dim layout: [real0, imag0, real1, imag1, ...]
-    - freqs_real/imag: (sl, hd//2)
+    - freqs_complex_ptr: (sl, hd//2, 2)
     """
     pid = tl.program_id(0).to(tl.int64)
     batch_idx = pid // sl
@@ -101,11 +56,14 @@ def _triton_llama4_rope_npu(
     hd_idx = tl.arange(0, hd)
     hd_mask = hd_idx < (hd)
 
-    freq_idx = tl.arange(0, hd // 2)
-    freq_mask = freq_idx < (hd // 2)
+    freq_idx = tl.arange(0, hd)
+    freq_mask = freq_idx < (hd)
 
-    freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
-    freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign
+    freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
+
+    freqs_complex = freqs_complex.reshape(hd // 2, 2, can_reorder=True)
+    freqs_real, freqs_imag = tl.split(freqs_complex)
+    freqs_imag = freqs_imag * imag_sign
 
     # Q heads (chunked for UB)
     for qh_block in range(0, n_qh, BLOCK_Q):
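
The interleaved load above replaces two separate real/imag loads: the kernel reads `hd` contiguous values per row and then splits them into `hd // 2` (real, imag) pairs. A host-side PyTorch sketch of the same deinterleave (the Triton kernel uses `reshape` + `tl.split`; the values here are made up for illustration):

```python
import torch

hd = 8  # head dimension, assumed even

# One flattened row of the (sl, hd // 2, 2) frequency buffer:
# [real0, imag0, real1, imag1, ...]
freqs_flat = torch.arange(hd, dtype=torch.float32)

# Equivalent of reshape(hd // 2, 2) followed by tl.split in the kernel.
pairs = freqs_flat.reshape(hd // 2, 2)
freqs_real, freqs_imag = pairs.unbind(dim=-1)

print(freqs_real)  # tensor([0., 2., 4., 6.])
print(freqs_imag)  # tensor([1., 3., 5., 7.])
```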
@@ -166,10 +124,14 @@ def llama4_rope_forward(q, k, freqs_cis):
     _, _, n_kh, _ = k.shape
     if hd % 2 != 0:
         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
-    hd_half = hd // 2
 
-    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
-    q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+    if freqs_cis.is_complex():
+        freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
+        if freqs_cis.shape[0] > sl:
+            freqs_cis = freqs_cis[:sl]
+        freqs_cis = torch.view_as_real(freqs_cis)
+
+    q, k, freqs_cis, compute_dtype = _cast_and_contiguous(q, k, freqs_cis)
 
     # UB tiling strategy: tile heads dimension only
     dtype_size = q.element_size()
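
The inline preprocessing above is intended to cover the complex-input case the removed `_prepare_freqs` handled: flatten any leading dims, keep the first `sl` rows, and expose real/imag pairs via `torch.view_as_real`. A small equivalence sketch under assumed shapes (a leading batch dim, as the old helper accepted):

```python
import torch

# Assumed shapes for illustration.
bs, sl, hd_half = 2, 4, 8

# Complex freqs with a leading batch dim.
freqs_cis = torch.polar(torch.ones(bs, sl, hd_half), torch.randn(bs, sl, hd_half))

# Old helper behavior (removed in this commit): flatten leading dims, take first sl rows.
old_real = freqs_cis.real.reshape(-1, hd_half)[:sl]
old_imag = freqs_cis.imag.reshape(-1, hd_half)[:sl]

# New inline path: flatten, slice, then view the complex values as (real, imag) pairs.
fc = torch.view_as_real(freqs_cis.reshape(-1, hd_half)[:sl])  # (sl, hd_half, 2)

assert torch.equal(fc[..., 0], old_real)
assert torch.equal(fc[..., 1], old_imag)
```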
@@ -195,13 +157,12 @@ def llama4_rope_forward(q, k, freqs_cis):
     _triton_llama4_rope_npu[(n_row,)](
         q,
         k,
-        freqs_real,
-        freqs_imag,
+        freqs_cis,
         q.stride(1),
         k.stride(1),
         q.stride(2),
         k.stride(2),
-        freqs_real.stride(0),
+        freqs_cis.stride(0),
         sl,
         bs,
         n_qh,
@@ -231,10 +192,14 @@ def llama4_rope_backward(dq, dk, freqs_cis):
     _, _, n_kh, _ = dk.shape
     if hd % 2 != 0:
         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
-    hd_half = hd // 2
 
-    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
-    dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag)
+    if freqs_cis.is_complex():
+        freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
+        if freqs_cis.shape[0] > sl:
+            freqs_cis = freqs_cis[:sl]
+        freqs_cis = torch.view_as_real(freqs_cis)
+
+    dq, dk, freqs_cis, compute_dtype = _cast_and_contiguous(dq, dk, freqs_cis)
 
     # UB tiling strategy: tile heads dimension only
     dtype_size = dq.element_size()
@@ -260,13 +225,12 @@ def llama4_rope_backward(dq, dk, freqs_cis):
     _triton_llama4_rope_npu[(n_row,)](
         dq,
         dk,
-        freqs_real,
-        freqs_imag,
+        freqs_cis,
         dq.stride(1),
         dk.stride(1),
         dq.stride(2),
         dk.stride(2),
-        freqs_real.stride(0),
+        freqs_cis.stride(0),
         sl,
         bs,
         n_qh,

src/liger_kernel/ops/llama4_rope.py

Lines changed: 21 additions & 66 deletions
@@ -3,72 +3,24 @@
 import triton.language as tl
 
 
-def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
-    # Split or unpack complex frequencies into real and imag parts
-    if freqs_cis.is_complex():
-        freqs_real = freqs_cis.real
-        freqs_imag = freqs_cis.imag
-    else:
-        # Already split: last dim should be 2*head_dim_half
-        if freqs_cis.shape[-1] == 2 * head_dim_half:
-            freqs_real = freqs_cis[..., :head_dim_half]
-            freqs_imag = freqs_cis[..., head_dim_half:]
-        else:
-            raise ValueError(
-                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
-            )
-
-    # Canonicalize to shape (seq_len, head_dim_half):
-    # 1) Ensure the last dimension is head_dim_half
-    if freqs_real.shape[-1] != head_dim_half:
-        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
-    # 2) Flatten all leading dims to a single row dimension
-    freqs_real = freqs_real.reshape(-1, head_dim_half)
-    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
-    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
-    if freqs_real.shape[0] < seq_len:
-        if freqs_real.shape[0] == 1:
-            freqs_real = freqs_real.expand(seq_len, -1)
-            freqs_imag = freqs_imag.expand(seq_len, -1)
-        else:
-            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
-    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
-    elif freqs_real.shape[0] > seq_len:
-        freqs_real = freqs_real[:seq_len]
-        freqs_imag = freqs_imag[:seq_len]
-
-    return freqs_real, freqs_imag
-
-
-def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-    return t if t.dtype == dtype else t.to(dtype)
-
-
-def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
-    return t if t.is_contiguous() else t.contiguous()
-
-
-def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
-    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+def _cast_and_contiguous(q, k, freqs_complex):
+    # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
     compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
 
-    # Make sure q/k share the same dtype before casting to compute dtype
     if k.dtype != q.dtype:
         k = k.to(q.dtype)
 
-    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
-    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
-    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
-    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
-    return q, k, freqs_real, freqs_imag
+    q = q.to(compute_dtype).contiguous()
+    k = k.to(compute_dtype).contiguous()
+    freqs_complex = freqs_complex.contiguous()
+    return q, k, freqs_complex
 
 
 @triton.jit
 def _llama4_rope_kernel(
     q_ptr,
     k_ptr,
-    freqs_real_ptr,
-    freqs_imag_ptr,
+    freqs_complex_ptr,
     q_row_stride,
     k_row_stride,
     q_head_stride,
@@ -101,16 +53,18 @@ def _llama4_rope_kernel(
     base_offset = batch_idx * seq_len + seq_idx
     q_base = q_ptr + base_offset * q_row_stride
     k_base = k_ptr + base_offset * k_row_stride
+    freq_base = seq_idx * freqs_row_stride
 
     # Tiling over dim/2
     for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
         d_indices = d_start + tl.arange(0, BLOCK_SIZE)
         mask_d = d_indices < head_dim_half
 
-        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
-        freq_idx = d_indices
-        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
-        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        # Compute offsets for the block
+        freq_offsets = d_indices[:, None] * 2 + tl.arange(0, 2)[None, :]
+        # Load the block
+        freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_offsets, mask=mask_d[:, None], other=0.0)
+        freqs_real, freqs_imag = tl.split(freqs_complex)
         freqs_imag = freqs_imag * imag_sign
 
         # Process one query head per program in pid_h
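
The 2D offsets introduced above address the interleaved (real, imag) pairs of one BLOCK_SIZE tile: pair index `d` maps to flat positions `2*d` and `2*d + 1` within the row. A small PyTorch sketch of the same indexing arithmetic (tile sizes and data are assumptions):

```python
import torch

BLOCK_SIZE, head_dim_half = 4, 16  # assumed tile parameters
d_start = 4
d_indices = d_start + torch.arange(BLOCK_SIZE)

# Same arithmetic as in the kernel: pair d -> flat positions 2*d and 2*d + 1.
freq_offsets = d_indices[:, None] * 2 + torch.arange(2)[None, :]
print(freq_offsets)
# tensor([[ 8,  9],
#         [10, 11],
#         [12, 13],
#         [14, 15]])

# Gathering with these offsets from a flat (real, imag, real, imag, ...) row
# yields the (BLOCK_SIZE, 2) block that tl.split separates into real and imag parts.
row = torch.arange(2 * head_dim_half, dtype=torch.float32)
block = row[freq_offsets]
freqs_real, freqs_imag = block.unbind(dim=-1)
```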
@@ -159,12 +113,14 @@ def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: floa
     batch_size, seq_len, n_q_heads, head_dim = q.shape
     _, _, n_k_heads, _ = k.shape
     head_dim_half = head_dim // 2
-
-    # Prepare frequencies
-    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+    if freqs_cis.is_complex():
+        freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
+        if freqs_cis.shape[0] > seq_len:
+            freqs_cis = freqs_cis[:seq_len]
+        freqs_cis = torch.view_as_real(freqs_cis)
 
     # Cast to appropriate dtype and make contiguous only when needed
-    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+    q, k, freqs_cis = _cast_and_contiguous(q, k, freqs_cis)
 
     # H100-optimized meta-params
     if BLOCK_SIZE is None:
@@ -181,13 +137,12 @@ def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: floa
     _llama4_rope_kernel[grid](
         q,
         k,
-        freqs_real,
-        freqs_imag,
+        freqs_cis,
         q.stride(1),
         k.stride(1),
         q.stride(2),
         k.stride(2),
-        freqs_real.stride(0),
+        freqs_cis.stride(0),
         seq_len,
         batch_size,
         imag_sign,
