Skip to content

Commit 08a3484

Browse files
committed
refactor mrope
1 parent 1321f2e commit 08a3484

File tree

5 files changed

+21
-35
lines changed

5 files changed

+21
-35
lines changed

lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
3333
self.position_ids = position_ids.unsqueeze(0).expand(3, -1)
3434

3535
self.position_ids = self.position_ids.contiguous()
36-
self._cos_cached = model._cos_cached
37-
self._sin_cached = model._sin_cached
36+
self.position_cos = model._cos_cached[self.position_ids]
37+
self.position_sin = model._sin_cached[self.position_ids]
3838
if get_env_start_args().enable_fa3:
3939
self.max_seq_len = self.max_kv_seq_len
4040
self.q_max_seq_len = self.max_q_seq_len

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ def _get_qkv(self, input, infer_state, layer_weight):
2121
mrope_triton_fused(
2222
q.view(-1, self.tp_q_head_num_, self.head_dim_),
2323
cache_kv[:, : self.tp_k_head_num_, :],
24-
infer_state._cos_cached,
25-
infer_state._sin_cached,
26-
infer_state.position_ids,
24+
infer_state.position_cos,
25+
infer_state.position_sin,
2726
self.mrope_section,
2827
is_interleaved=False,
2928
)

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -74,40 +74,33 @@ def _mrope_triton_fused_kernel(
7474
Cos,
7575
Sin,
7676
mrope_section,
77-
position_ids,
78-
stride_positions,
77+
stride_cosld,
78+
stride_cosd,
79+
stride_sinld,
80+
stride_sind,
7981
stride_qbs,
8082
stride_qh,
8183
stride_qd,
8284
stride_kbs,
8385
stride_kh,
8486
stride_kd,
85-
stride_cosbs,
86-
stride_cosd,
87-
stride_sinbs,
88-
stride_sind,
8987
is_interleaved: tl.constexpr,
9088
HEAD_Q: tl.constexpr,
9189
HEAD_K: tl.constexpr,
9290
BLOCK_DMODEL: tl.constexpr,
93-
NUM_STAGE: tl.constexpr,
9491
):
9592
head_index = tl.program_id(0)
9693
seq_index = tl.program_id(1)
9794

9895
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)
9996
dim_range1 = dim_range0 + BLOCK_DMODEL // 2
10097

101-
t = tl.load(position_ids + 0 * stride_positions + seq_index)
102-
h = tl.load(position_ids + 1 * stride_positions + seq_index)
103-
w = tl.load(position_ids + 2 * stride_positions + seq_index)
104-
105-
t_cos = Cos + t * stride_cosbs
106-
h_cos = Cos + h * stride_cosbs
107-
w_cos = Cos + w * stride_cosbs
108-
t_sin = Sin + t * stride_sinbs
109-
h_sin = Sin + h * stride_sinbs
110-
w_sin = Sin + w * stride_sinbs
98+
t_cos = Cos + seq_index * stride_cosd
99+
h_cos = Cos + stride_cosld + seq_index * stride_cosd
100+
w_cos = Cos + 2 * stride_cosld + seq_index * stride_cosd
101+
t_sin = Sin + seq_index * stride_sind
102+
h_sin = Sin + stride_sinld + seq_index * stride_sind
103+
w_sin = Sin + 2 * stride_sinld + seq_index * stride_sind
111104

112105
mrope_section_t = tl.load(mrope_section + 0)
113106
mrope_section_h = tl.load(mrope_section + 1)
@@ -198,7 +191,6 @@ def mrope_triton_fused(
198191
k: torch.Tensor,
199192
cos: torch.Tensor,
200193
sin: torch.Tensor,
201-
position_ids: torch.Tensor,
202194
mrope_section: torch.Tensor,
203195
is_interleaved: bool,
204196
run_config: Optional[dict] = None,
@@ -224,24 +216,21 @@ def mrope_triton_fused(
224216
k=k,
225217
Cos=cos,
226218
Sin=sin,
227-
position_ids=position_ids,
228219
mrope_section=mrope_section,
229-
stride_positions=position_ids.stride(0),
220+
stride_cosld=cos.stride(0),
221+
stride_cosd=cos.stride(1),
222+
stride_sinld=sin.stride(0),
223+
stride_sind=sin.stride(1),
230224
stride_qbs=q.stride(0),
231225
stride_qh=q.stride(1),
232226
stride_qd=q.stride(2),
233227
stride_kbs=k.stride(0),
234228
stride_kh=k.stride(1),
235229
stride_kd=k.stride(2),
236-
stride_cosbs=cos.stride(0),
237-
stride_cosd=cos.stride(1),
238-
stride_sinbs=sin.stride(0),
239-
stride_sind=sin.stride(1),
240230
is_interleaved=is_interleaved,
241231
HEAD_Q=head_num_q,
242232
HEAD_K=head_num_k,
243233
BLOCK_DMODEL=head_dim,
244-
NUM_STAGE=num_stages,
245234
num_warps=num_warps,
246235
num_stages=num_stages,
247236
)

lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ def _get_qkv(
5151
mrope_triton_fused(
5252
q.view(-1, self.tp_q_head_num_, self.head_dim_),
5353
cache_kv[:, : self.tp_k_head_num_, :],
54-
infer_state._cos_cached,
55-
infer_state._sin_cached,
56-
infer_state.position_ids,
54+
infer_state.position_cos,
55+
infer_state.position_sin,
5756
self.mrope_section,
5857
is_interleaved=True,
5958
)

unit_tests/models/qwen2_vl/test_mrope.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,8 @@ def test_mrope_triton_correctness(B, H_q, H_k, L, D, mrope_section):
5454

5555
q = q.transpose(1, 2).contiguous().view(L, H_q, D)
5656
k = k.transpose(1, 2).contiguous().view(L, H_k, D)
57-
position_ids = torch.arange(L, dtype=torch.int32, device="cuda").unsqueeze(0).expand(3, L).contiguous()
5857
mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
59-
mrope_triton_fused(q, k, cos_half[0], sin_half[0], position_ids, mrope_section, is_interleaved=False)
58+
mrope_triton_fused(q, k, cos_half, sin_half, mrope_section, is_interleaved=False)
6059
q = q.transpose(0, 1).contiguous().view(B, H_q, L, D)
6160
k = k.transpose(0, 1).contiguous().view(B, H_k, L, D)
6261
assert torch.allclose(q, ref_q, rtol=1e-3, atol=1e-3)

0 commit comments

Comments
 (0)