Skip to content

Commit 7fb3692

Browse files
author
sangchengmeng
committed
[add] mrope triton
1 parent 06e37ed commit 7fb3692

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
88
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
99

10-
1110
class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
1211
def __init__(self, layer_num, network_config, mode=[]):
1312
super().__init__(layer_num, network_config, mode)
@@ -19,10 +18,10 @@ def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
1918
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
2019
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
2120
seq_len, _ = q.shape
22-
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
23-
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
21+
q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
22+
k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2).contiguous()
2423
new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section)
25-
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
24+
new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1).contiguous()
2625
cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
2726

2827
return new_q, cache_kv

lightllm/models/qwen2_vl/triton_kernel/mrope.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import triton
44
import triton.language as tl
55

6-
76
@triton.jit
87
def mrope_kernel_combined(
98
Q_ptr,
@@ -170,6 +169,3 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
170169
print(f"torch {t_ref:.2f} ms/iter")
171170
print(f"triton {t_tri:.2f} ms/iter")
172171

173-
174-
if __name__ == "__main__":
175-
test()

0 commit comments

Comments
 (0)