
Commit f92a96d

[NPU]: Add NPU support for the mrope operator

1 parent: 6a38342

File tree: 3 files changed, +294 -2 lines

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -17,14 +17,21 @@
 from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
 from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
 from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
 from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
 from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
 
+
 __all__ = [
     "LigerGELUMulFunction",
     "geglu_forward",
     "geglu_backward",
+    "LigerQwen2VLMRopeFunction",
+    "qwen2vl_mrope_forward",
+    "qwen2vl_mrope_backward",
     "LigerRopeFunction",
     "rope_forward",
     "rope_backward",

src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py

Lines changed: 285 additions & 0 deletions (new file)
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy


@triton.jit
def _triton_qwen2vl_mrope_npu(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    sin,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    pid = tl.program_id(0).to(tl.int64)

    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    t_cos = cos + pid * hd
    h_cos = t_cos + bs * sl * hd
    w_cos = h_cos + bs * sl * hd
    t_sin = sin + pid * hd
    h_sin = t_sin + bs * sl * hd
    w_sin = h_sin + bs * sl * hd

    q_base = q_ptr + pid * q_row_stride
    k_base = k_ptr + pid * k_row_stride

    d_idx = tl.arange(0, hd // 2)
    d_mask = d_idx < (hd // 2)

    pos_mask_t = d_idx < t_end
    pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)

    text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
    text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
    height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
    height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
    width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
    width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)

    cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
    sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))

    for qh_block in range(0, n_qh, BLOCK_Q):
        qh_idx = tl.arange(0, BLOCK_Q) + qh_block
        qh_mask = qh_idx < n_qh

        block_mask = qh_mask[:, None] & d_mask[None, :]
        offsets = qh_idx[:, None] * hd + d_idx[None, :]

        q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
        q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)

        if not BACKWARD_PASS:
            new_left = q_left * cos_vals - q_right * sin_vals
            new_right = q_right * cos_vals + q_left * sin_vals
        else:
            new_left = q_left * cos_vals + q_right * sin_vals
            new_right = q_right * cos_vals - q_left * sin_vals

        tl.store(q_base + offsets, new_left, mask=block_mask)
        tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)

    for kh_block in range(0, n_kh, BLOCK_K):
        kh_idx = tl.arange(0, BLOCK_K) + kh_block
        kh_mask = kh_idx < n_kh

        block_mask = kh_mask[:, None] & d_mask[None, :]
        offsets = kh_idx[:, None] * hd + d_idx[None, :]

        k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
        k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)

        if not BACKWARD_PASS:
            new_left = k_left * cos_vals - k_right * sin_vals
            new_right = k_right * cos_vals + k_left * sin_vals
        else:
            new_left = k_left * cos_vals + k_right * sin_vals
            new_right = k_right * cos_vals - k_left * sin_vals

        tl.store(k_base + offsets, new_left, mask=block_mask)
        tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)


def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous; this is a no-op if they already are
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    # Compute tiling strategy based on UB capacity
    dtype_size = q.element_size()
    # MROPE forward tiling strategy:
    # - cos_vals and sin_vals (including the text, height and width streams) are loaded once outside the loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
    # - In q heads loop (peak memory):
    #   * q_left: BLOCK_Q * (pad_hd // 2) elements
    #   * q_right: BLOCK_Q * (pad_hd // 2) elements
    #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
    # - In k heads loop (peak memory):
    #   * k_left: BLOCK_K * (pad_hd // 2) elements
    #   * k_right: BLOCK_K * (pad_hd // 2) elements
    #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
    # - Since q and k are processed separately, peak memory is the max(BLOCK_Q, BLOCK_K) case
    # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
    # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
    # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=3.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        # Fallback to conservative defaults
        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
    _triton_qwen2vl_mrope_npu[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # Compute tiling strategy based on UB capacity
    dtype_size = dq.element_size()
    # MROPE backward tiling strategy:
    # - cos_vals and sin_vals (including the text, height and width streams) are loaded once outside the loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
    # - In q heads loop (peak memory):
    #   * q_left: BLOCK_Q * (pad_hd // 2) elements
    #   * q_right: BLOCK_Q * (pad_hd // 2) elements
    #   * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
    # - In k heads loop (peak memory):
    #   * k_left: BLOCK_K * (pad_hd // 2) elements
    #   * k_right: BLOCK_K * (pad_hd // 2) elements
    #   * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
    #   * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
    # - Since q and k are processed separately, peak memory is the max(BLOCK_Q, BLOCK_K) case
    # - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
    # - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
    # - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
    # - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
    # - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    # - tiling_dims: (0, 0) means the first dimension of each shape can be tiled
    # - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
    shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dtype_size,
        memory_multiplier=3.0,
        shapes=shapes,
        tiling_dims=(0, 0),
    )

    if tile_shapes is not None and len(tile_shapes) == len(shapes):
        # Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        q_tile_shape, k_tile_shape = tile_shapes
        BLOCK_Q, _ = q_tile_shape
        BLOCK_K, _ = k_tile_shape
    else:
        # Fallback to conservative defaults
        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
    _triton_qwen2vl_mrope_npu[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_Q,
        BLOCK_K,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None
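
To make the kernel's math easier to check, here is a rough eager-mode sketch of what one grid row computes: for each index in the first half of head_dim, cos/sin are taken from the temporal, height, or width stream according to mrope_section, and the two halves of every head are then rotated. mrope_reference is a hypothetical helper written under the docstring layouts above, not an existing API or the upstream transformers implementation.

import torch


def mrope_reference(x, cos, sin, mrope_section):
    # x: (bsz, n_head, seq_len, head_dim); cos/sin: (3, bsz, seq_len, head_dim).
    # Only the first head_dim // 2 positions of cos/sin are read, as in the kernel.
    half = x.shape[-1] // 2
    t_end = mrope_section[0]
    h_end = t_end + mrope_section[1]
    # Stitch per-position cos/sin from the temporal / height / width streams,
    # mirroring the tl.where(pos_mask_t, ..., tl.where(pos_mask_h, ...)) chain.
    c = torch.cat([cos[0, ..., :t_end], cos[1, ..., t_end:h_end], cos[2, ..., h_end:half]], dim=-1)
    s = torch.cat([sin[0, ..., :t_end], sin[1, ..., t_end:h_end], sin[2, ..., h_end:half]], dim=-1)
    c = c.unsqueeze(1)  # broadcast over the head dimension
    s = s.unsqueeze(1)
    x1, x2 = x[..., :half], x[..., half:]
    # Same rotation as the forward kernel: left*cos - right*sin, right*cos + left*sin.
    return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)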

src/liger_kernel/ops/backends/_ascend/ops/rope.py

Lines changed: 2 additions & 2 deletions

@@ -239,8 +239,8 @@ def rope_backward(dq, dk, cos, sin):
         BLOCK_K, _ = k_tile_shape
     else:
         # Fallback to conservative defaults
-        BLOCK_Q = min(32, triton.next_power_of_2(pad_n_q_head))
-        BLOCK_K = min(32, triton.next_power_of_2(pad_n_kv_head))
+        BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
+        BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
 
     _triton_rope_npu[(n_row,)](
         dq,
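
The rope.py hunk only drops the min(32, ...) cap from the fallback path, bringing it in line with the fallback used by the new mrope functions; the tiling-strategy branch above it is untouched. A small illustration with a hypothetical padded head count:

import triton

pad_n_q_head = 64                                            # hypothetical value
old_block_q = min(32, triton.next_power_of_2(pad_n_q_head))  # previous fallback: capped at 32
new_block_q = triton.next_power_of_2(pad_n_q_head)           # new fallback: 64, the full padded head count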
