Skip to content

Commit 1f51687

Browse files
authored
[NPU]: optimize rope and mrope implementation (#1041)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Use a smaller grid and make it as close as possible to npu_core_count, and employ a pipeline operation with a tl.range of num stages. This way, the same core can handle as many rows as possible, thereby enhancing performance. ## Testing Done <img width="1689" height="402" alt="image" src="https://github.com/user-attachments/assets/9dcf35b5-b3d3-450a-8346-6ff640ed4163" /> <img width="1703" height="407" alt="image" src="https://github.com/user-attachments/assets/8f350c96-a8d9-4af0-ab97-53e107658bd1" /> - Hardware Type: Ascend NPU 910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 7aa3a4b commit 1f51687

File tree

2 files changed

+208
-243
lines changed

2 files changed

+208
-243
lines changed

src/liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py

Lines changed: 111 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import triton.language as tl
44

55
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
6+
from liger_kernel.ops.utils import get_npu_core_count
67

78

89
@triton.jit
@@ -15,111 +16,102 @@ def _triton_qwen2vl_mrope_npu(
1516
sin,
1617
sl,
1718
bs: tl.constexpr,
19+
total_rows: tl.constexpr,
1820
n_qh: tl.constexpr,
1921
n_kh: tl.constexpr,
2022
hd: tl.constexpr,
2123
mrope_section_t: tl.constexpr,
2224
mrope_section_h: tl.constexpr,
2325
BLOCK_Q: tl.constexpr,
2426
BLOCK_K: tl.constexpr,
27+
NUM_STAGES: tl.constexpr,
2528
BACKWARD_PASS: tl.constexpr = False,
2629
):
27-
pid = tl.program_id(0).to(tl.int64)
30+
program_id = tl.program_id(0)
31+
num_programs = tl.num_programs(0)
2832

29-
t_end = mrope_section_t
30-
h_end = t_end + mrope_section_h
33+
rows_per_program = (total_rows + num_programs - 1) // num_programs
34+
start_row = program_id * rows_per_program
35+
actual_rows = tl.minimum(rows_per_program, total_rows - start_row)
3136

32-
t_cos = cos + pid * hd
33-
h_cos = t_cos + bs * sl * hd
34-
w_cos = h_cos + bs * sl * hd
35-
t_sin = sin + pid * hd
36-
h_sin = t_sin + bs * sl * hd
37-
w_sin = h_sin + bs * sl * hd
37+
for row_offset in tl.range(0, actual_rows, num_stages=NUM_STAGES):
38+
pid = start_row + row_offset
3839

39-
q_base = q_ptr + pid * q_row_stride
40-
k_base = k_ptr + pid * k_row_stride
40+
t_end = mrope_section_t
41+
h_end = t_end + mrope_section_h
4142

42-
d_idx = tl.arange(0, hd // 2)
43-
d_mask = d_idx < (hd // 2)
43+
t_cos = cos + pid * hd
44+
h_cos = t_cos + bs * sl * hd
45+
w_cos = h_cos + bs * sl * hd
46+
t_sin = sin + pid * hd
47+
h_sin = t_sin + bs * sl * hd
48+
w_sin = h_sin + bs * sl * hd
4449

45-
pos_mask_t = d_idx < t_end
46-
pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
50+
q_base = q_ptr + pid * q_row_stride
51+
k_base = k_ptr + pid * k_row_stride
4752

48-
text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
49-
text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
50-
height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
51-
height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
52-
width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
53-
width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
53+
d_idx = tl.arange(0, hd // 2)
54+
d_mask = d_idx < (hd // 2)
5455

55-
cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
56-
sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
56+
pos_mask_t = d_idx < t_end
57+
pos_mask_h = (d_idx >= t_end) & (d_idx < h_end)
5758

58-
for qh_block in range(0, n_qh, BLOCK_Q):
59-
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
60-
qh_mask = qh_idx < n_qh
59+
text_cos_vals = tl.load(t_cos + d_idx, mask=d_mask, other=0)
60+
text_sin_vals = tl.load(t_sin + d_idx, mask=d_mask, other=0)
61+
height_cos_vals = tl.load(h_cos + d_idx, mask=d_mask, other=0)
62+
height_sin_vals = tl.load(h_sin + d_idx, mask=d_mask, other=0)
63+
width_cos_vals = tl.load(w_cos + d_idx, mask=d_mask, other=0)
64+
width_sin_vals = tl.load(w_sin + d_idx, mask=d_mask, other=0)
6165

62-
block_mask = qh_mask[:, None] & d_mask[None, :]
63-
offsets = qh_idx[:, None] * hd + d_idx[None, :]
66+
cos_vals = tl.where(pos_mask_t, text_cos_vals, tl.where(pos_mask_h, height_cos_vals, width_cos_vals))
67+
sin_vals = tl.where(pos_mask_t, text_sin_vals, tl.where(pos_mask_h, height_sin_vals, width_sin_vals))
6468

65-
q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
66-
q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
69+
# Process q heads in chunks to prevent UB overflow
70+
for qh_block in range(0, n_qh, BLOCK_Q):
71+
qh_idx = tl.arange(0, BLOCK_Q) + qh_block
72+
qh_mask = qh_idx < n_qh
6773

68-
if not BACKWARD_PASS:
69-
new_left = q_left * cos_vals - q_right * sin_vals
70-
new_right = q_right * cos_vals + q_left * sin_vals
71-
else:
72-
new_left = q_left * cos_vals + q_right * sin_vals
73-
new_right = q_right * cos_vals - q_left * sin_vals
74+
block_mask = qh_mask[:, None] & d_mask[None, :]
75+
offsets = qh_idx[:, None] * hd + d_idx[None, :]
7476

75-
tl.store(q_base + offsets, new_left, mask=block_mask)
76-
tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
77+
q_left = tl.load(q_base + offsets, mask=block_mask, other=0)
78+
q_right = tl.load(q_base + offsets + (hd // 2), mask=block_mask, other=0)
7779

78-
for kh_block in range(0, n_kh, BLOCK_K):
79-
kh_idx = tl.arange(0, BLOCK_K) + kh_block
80-
kh_mask = kh_idx < n_kh
80+
if not BACKWARD_PASS:
81+
new_left = q_left * cos_vals - q_right * sin_vals
82+
new_right = q_right * cos_vals + q_left * sin_vals
83+
else:
84+
new_left = q_left * cos_vals + q_right * sin_vals
85+
new_right = q_right * cos_vals - q_left * sin_vals
8186

82-
block_mask = kh_mask[:, None] & d_mask[None, :]
83-
offsets = kh_idx[:, None] * hd + d_idx[None, :]
87+
tl.store(q_base + offsets, new_left, mask=block_mask)
88+
tl.store(q_base + offsets + (hd // 2), new_right, mask=block_mask)
8489

85-
k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
86-
k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
90+
# Process k heads in chunks to prevent UB overflow
91+
for kh_block in range(0, n_kh, BLOCK_K):
92+
kh_idx = tl.arange(0, BLOCK_K) + kh_block
93+
kh_mask = kh_idx < n_kh
8794

88-
if not BACKWARD_PASS:
89-
new_left = k_left * cos_vals - k_right * sin_vals
90-
new_right = k_right * cos_vals + k_left * sin_vals
91-
else:
92-
new_left = k_left * cos_vals + k_right * sin_vals
93-
new_right = k_right * cos_vals - k_left * sin_vals
95+
block_mask = kh_mask[:, None] & d_mask[None, :]
96+
offsets = kh_idx[:, None] * hd + d_idx[None, :]
9497

95-
tl.store(k_base + offsets, new_left, mask=block_mask)
96-
tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
98+
k_left = tl.load(k_base + offsets, mask=block_mask, other=0)
99+
k_right = tl.load(k_base + offsets + (hd // 2), mask=block_mask, other=0)
97100

101+
if not BACKWARD_PASS:
102+
new_left = k_left * cos_vals - k_right * sin_vals
103+
new_right = k_right * cos_vals + k_left * sin_vals
104+
else:
105+
new_left = k_left * cos_vals + k_right * sin_vals
106+
new_right = k_right * cos_vals - k_left * sin_vals
98107

99-
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
100-
# transpose it back to the physical shape because Triton looks at the physical storage
101-
# note: q and k are incontiguous before the transformation and will become contiguous after transpose
102-
q = q.transpose(1, 2)
103-
k = k.transpose(1, 2)
108+
tl.store(k_base + offsets, new_left, mask=block_mask)
109+
tl.store(k_base + offsets + (hd // 2), new_right, mask=block_mask)
104110

105-
batch_size, seq_len, n_q_head, head_dim = q.shape
106-
n_kv_head = k.shape[2]
107-
pad_hd = triton.next_power_of_2(head_dim)
108-
pad_n_q_head = triton.next_power_of_2(n_q_head)
109-
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
110111

111-
n_row = batch_size * seq_len
112-
113-
# ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
114-
q = q.contiguous()
115-
k = k.contiguous()
116-
cos = cos.contiguous()
117-
sin = sin.contiguous()
118-
119-
# Compute tiling strategy based on UB capacity
120-
dtype_size = q.element_size()
112+
def get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size):
121113
# MROPE forward tiling strategy:
122-
# - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
114+
# - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 6 = 3 * pad_hd elements each
123115
# - In q heads loop (peak memory):
124116
# * q_left: BLOCK_Q * (pad_hd // 2) elements
125117
# * q_right: BLOCK_Q * (pad_hd // 2) elements
@@ -133,9 +125,9 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
133125
# * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
134126
# * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
135127
# - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
136-
# - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
137-
# - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
138-
# - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
128+
# - Plus shared cos/sin: 6 * (pad_hd // 2) = 3 * pad_hd elements
129+
# - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + 3 * pad_hd) * dtype_size * 8 bits
130+
# - Simplified: (2 * BLOCK_SIZE + 3) * pad_hd * dtype_size * 8 bits
139131
# - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
140132
# - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
141133
# - tiling_dims: (0, 0) means first dimension of each shape can be tiled
@@ -156,9 +148,38 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
156148
BLOCK_K, _ = k_tile_shape
157149
else:
158150
# Fallback to conservative defaults
159-
BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
160-
BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
161-
_triton_qwen2vl_mrope_npu[(n_row,)](
151+
BLOCK_Q = 2048
152+
BLOCK_K = 2048
153+
154+
return BLOCK_Q, BLOCK_K
155+
156+
157+
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
158+
# transpose it back to the physical shape because Triton looks at the physical storage
159+
q = q.transpose(1, 2)
160+
k = k.transpose(1, 2)
161+
162+
batch_size, seq_len, n_q_head, head_dim = q.shape
163+
n_kv_head = k.shape[2]
164+
pad_hd = triton.next_power_of_2(head_dim)
165+
pad_n_q_head = triton.next_power_of_2(n_q_head)
166+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
167+
168+
n_row = batch_size * seq_len
169+
170+
# ensure tensors passed into the kernel are contiguous
171+
q = q.contiguous()
172+
k = k.contiguous()
173+
cos = cos.contiguous()
174+
sin = sin.contiguous()
175+
176+
dtype_size = q.element_size()
177+
BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
178+
179+
num_cores = get_npu_core_count()
180+
grid_size = min(num_cores, n_row)
181+
182+
_triton_qwen2vl_mrope_npu[(grid_size,)](
162183
q,
163184
q.stride(1),
164185
k,
@@ -167,13 +188,15 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
167188
sin,
168189
seq_len,
169190
batch_size,
191+
n_row,
170192
n_q_head,
171193
n_kv_head,
172194
head_dim,
173195
mrope_section[0],
174196
mrope_section[1],
175197
BLOCK_Q,
176198
BLOCK_K,
199+
NUM_STAGES=3,
177200
BACKWARD_PASS=False,
178201
)
179202
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
@@ -195,49 +218,13 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
195218
dq = dq.contiguous()
196219
dk = dk.contiguous()
197220

198-
# Compute tiling strategy based on UB capacity
199221
dtype_size = dq.element_size()
200-
# MROPE backward tiling strategy:
201-
# - cos_vals and sin_vals (include text, height and width) are loaded once outside loops (shared): (pad_hd // 2) * 4 = 2 * pad_hd elements each
202-
# - In q heads loop (peak memory):
203-
# * q_left: BLOCK_Q * (pad_hd // 2) elements
204-
# * q_right: BLOCK_Q * (pad_hd // 2) elements
205-
# * new_left: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
206-
# * new_right: BLOCK_Q * (pad_hd // 2) elements (intermediate result)
207-
# * Total: 4 * BLOCK_Q * (pad_hd // 2) = 2 * BLOCK_Q * pad_hd elements
208-
# - In k heads loop (peak memory):
209-
# * k_left: BLOCK_K * (pad_hd // 2) elements
210-
# * k_right: BLOCK_K * (pad_hd // 2) elements
211-
# * new_left: BLOCK_K * (pad_hd // 2) elements (intermediate result)
212-
# * new_right: BLOCK_K * (pad_hd // 2) elements (intermediate result)
213-
# * Total: 4 * BLOCK_K * (pad_hd // 2) = 2 * BLOCK_K * pad_hd elements
214-
# - Since q and k are processed separately, peak memory is max(BLOCK_Q, BLOCK_K) case
215-
# - Plus shared cos/sin: 2 * (pad_hd // 2) = pad_hd elements
216-
# - Conservative estimate: (2 * BLOCK_SIZE * pad_hd + pad_hd) * dtype_size * 8 bits
217-
# - Simplified: (2 * BLOCK_SIZE + 2) * pad_hd * dtype_size * 8 bits
218-
# - For safety, use: memory_multiplier=3.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
219-
# - shapes: ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
220-
# - tiling_dims: (0, 0) means first dimension of each shape can be tiled
221-
# - Returns: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
222-
shapes = ((pad_n_q_head, pad_hd), (pad_n_kv_head, pad_hd))
223-
tile_shapes = compute_default_tiling_strategy(
224-
safety_margin=0.90,
225-
dtype_size=dtype_size,
226-
memory_multiplier=3.0,
227-
shapes=shapes,
228-
tiling_dims=(0, 0),
229-
)
222+
BLOCK_Q, BLOCK_K = get_optimal_block_size_mrope(pad_n_q_head, pad_n_kv_head, pad_hd, dtype_size)
230223

231-
if tile_shapes is not None and len(tile_shapes) == len(shapes):
232-
# Strategy returns ((block_size_q, pad_hd), (block_size_kv, pad_hd))
233-
q_tile_shape, k_tile_shape = tile_shapes
234-
BLOCK_Q, _ = q_tile_shape
235-
BLOCK_K, _ = k_tile_shape
236-
else:
237-
# Fallback to conservative defaults
238-
BLOCK_Q = triton.next_power_of_2(pad_n_q_head)
239-
BLOCK_K = triton.next_power_of_2(pad_n_kv_head)
240-
_triton_qwen2vl_mrope_npu[(n_row,)](
224+
num_cores = get_npu_core_count()
225+
grid_size = min(num_cores, n_row)
226+
227+
_triton_qwen2vl_mrope_npu[(grid_size,)](
241228
dq,
242229
dq.stride(1),
243230
dk,
@@ -246,13 +233,15 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
246233
sin,
247234
seq_len,
248235
batch_size,
236+
n_row,
249237
n_q_head,
250238
n_kv_head,
251239
head_dim,
252240
mrope_section[0],
253241
mrope_section[1],
254242
BLOCK_Q,
255243
BLOCK_K,
244+
NUM_STAGES=3,
256245
BACKWARD_PASS=True,
257246
)
258247
return dq.transpose(1, 2), dk.transpose(1, 2)
@@ -272,6 +261,7 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
272261
ctx.mrope_section = mrope_section
273262
return q, k
274263

264+
@staticmethod
275265
def backward(ctx, dq, dk):
276266
"""
277267
dq size: (bsz, n_q_head, seq_len, head_dim)

0 commit comments

Comments
 (0)