
Commit 974d775

shihaobai, sangchengmeng, and wangzaijun authored

mrope improved (#1147)

Co-authored-by: sangchengmeng <[email protected]>
Co-authored-by: wangzaijun <[email protected]>

1 parent 8ddcadc, commit 974d775

14 files changed: +387 additions, -125 deletions

lightllm/models/llama/model.py

Lines changed: 2 additions & 46 deletions
@@ -118,7 +118,7 @@ def _init_custom(self):
             scaling_type = rope_scaling["type"]
         else:
             raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
-        if scaling_type == "default":
+        if scaling_type == "default" or "mrope_section" in rope_scaling:
             self._init_to_get_rotary()
         elif scaling_type == "yarn":
             self._init_to_get_yarn_rotary()
@@ -129,7 +129,7 @@ def _init_custom(self):
         elif scaling_type == "llama3":
             self._init_to_get_llama3_rotary()
         elif scaling_type == "mrope":
-            self._init_to_get_mrope_rotary()
+            self._init_to_get_rotary()
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         return
@@ -373,47 +373,3 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
-
-    def _init_to_get_mrope_rotary(self, default_base=10000):
-        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
-        if self.config.get("rope_scaling", {}) is None:
-            rope_scaling_factor = 1.0
-        else:
-            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-
-        base = self.config.get("rope_theta", float(default_base))
-
-        if "max_sequence_length" in self.config:
-            max_seq_len = self.config["max_sequence_length"]
-        else:
-            max_position_embeddings = self.config.get(
-                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
-            )
-            max_seq_len = max_position_embeddings * rope_scaling_factor
-
-        # NTK
-        try:
-            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
-            assert ntk_alpha >= 1
-            if ntk_alpha > 1:
-                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
-            max_seq_len *= ntk_alpha
-            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
-        except:
-            pass
-
-        inv_freq = 1.0 / (
-            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
-        )
-
-        t = (
-            torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
-            / rope_scaling_factor
-        )
-        freqs = torch.outer(t, inv_freq).unsqueeze(0).expand(3, -1, -1)
-        freqs = torch.cat((freqs, freqs), dim=-1)
-
-        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
-        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
-
-        return
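Note on the change above: with _init_to_get_mrope_rotary removed, the mrope path reuses the plain _init_to_get_rotary cache, and only the indexing differs. A minimal sketch of the idea, using illustrative shapes and values rather than lightllm's actual cache-building code:

import torch

# Sketch only: a single 2-D rotary cache serves mrope because the three
# position axes (temporal / height / width) all index the same table.
max_len, rot_dim, base = 8192, 128, 10000.0          # illustrative values
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
t = torch.arange(max_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                     # (max_len, rot_dim // 2)
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)

position_ids = torch.zeros(3, 10, dtype=torch.long)  # (3, L) mrope position ids
position_cos = cos_cached[position_ids]              # (3, L, rot_dim // 2), as used in Qwen2VLInferStateInfo below
position_sin = sin_cached[position_ids]

This is why the dedicated 3-axis cache, which expanded freqs to (3, L, D) up front, is no longer needed.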

lightllm/models/qwen2_vl/flashattention_infer_struct.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 63 additions & 10 deletions
@@ -1,10 +1,16 @@
+from typing import Optional, List
 import torch
 import numpy as np
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.models.qwen2_vl.triton_kernel.get_mrope_position_ids import get_mrope_position_triton
+from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
+from lightllm.utils.envs_utils import get_env_start_args


 class Qwen2VLInferStateInfo(LlamaInferStateInfo):
+    init_flash_attention_state_func = FlashAttentionStateInfo._init_flash_attention_state
+
     def __init__(self):
         super().__init__()
         self.position_cos = None
@@ -13,17 +19,64 @@ def __init__(self):
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         rope_scaling = model.config.get("rope_scaling", {})
         self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
-        if self.rope_type != "mrope":
-            super().init_some_extra_state(model, input_ids)
-            return
         InferStateInfo.init_some_extra_state(self, model, input_ids)
         if self.is_prefill:
-            position_ids = self.position_ids
-            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
-            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
-            position_ids = None
+            self.position_ids = self.get_mrope_position(self.multimodal_params)
         else:
-            position_ids = self.position_ids
-            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
-            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
+            for batch_idx, p in enumerate(self.multimodal_params):
+                position_delta = 0
+                for image in p["images"]:
+                    position_delta += image["grid_thwd"][3]
+                b_position_delta[batch_idx] = position_delta
+            position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
+            self.position_ids = position_ids.unsqueeze(0).expand(3, -1)
+
+        self.position_ids = self.position_ids.contiguous()
+        self.position_cos = model._cos_cached[self.position_ids]  # (3, L, D)
+        self.position_sin = model._sin_cached[self.position_ids]  # (3, L, D)
+        if get_env_start_args().enable_fa3:
+            self.max_seq_len = self.max_kv_seq_len
+            self.q_max_seq_len = self.max_q_seq_len
+            self.init_flash_attention_state_func(model, input_ids)
         return
+
+    def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
+        if len(multimodal_params) == 0:
+            return self.position_ids.unsqueeze(0).expand(3, -1)
+        b_image_start_idx = []
+        b_image_nums = []
+        b_image_start_num = []
+        b_image_len = []
+        image_start_num = 0
+        b_image_thwd = []
+        for _, p in enumerate(multimodal_params):
+            images = p.get("images", [])
+            for img in images:
+                b_image_start_idx.append(img["start_idx"])
+                b_image_len.append(img["token_num"])
+                b_image_thwd.append(img["grid_thwd"])
+            b_image_nums.append(len(images))
+            b_image_start_num.append(image_start_num)
+            image_start_num += len(images)
+        # no images at all
+        if image_start_num == 0:
+            return self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
+        b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+        b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
+        b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+        b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+        b_image_len = torch.tensor(b_image_len, device=self.position_ids.device)
+        position_ids = self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
+        get_mrope_position_triton(
+            b_image_start_idx=b_image_start_idx,
+            b_image_thwd=b_image_thwd,
+            b_image_nums=b_image_nums,
+            b_image_start_num=b_image_start_num,
+            b_image_len=b_image_len,
+            position_ids=position_ids,
+            b_ready_cache_len=self.b_ready_cache_len,
+            b_q_seq_len=self.b_q_seq_len,
+            b_start_loc=self.b_start_loc,
+        )
+        return position_ids
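For reference, the layout that get_mrope_position_triton is expected to fill follows the Qwen2-VL mrope scheme: text tokens advance all three axes together, image tokens take temporal/height/width indices from the image grid, and text after an image resumes from the largest grid extent, which is where the negative grid_thwd[3] delta used in the decode branch comes from. A CPU-side sketch of that scheme (a hedged reference, not the Triton kernel itself; it only relies on the start_idx and grid_thwd fields set by the tokenizer):

import torch

def mrope_positions(seq_len, images, start=0):
    # Reference sketch: returns a (3, seq_len) tensor of (temporal, height, width) ids.
    # Each entry of `images` mirrors the fields used above: "start_idx" (index of the
    # first image token in the prompt) and "grid_thwd", whose first three values are
    # the merged (t, h, w) grid of the image.
    pos = torch.empty(3, seq_len, dtype=torch.long)
    cur, idx = start, 0                                      # next text position / current token index
    for img in sorted(images, key=lambda im: im["start_idx"]):
        s = img["start_idx"]
        t, h, w = img["grid_thwd"][:3]
        pos[:, idx:s] = torch.arange(cur, cur + (s - idx))   # text: all axes move together
        cur += s - idx
        ti, hi, wi = torch.meshgrid(
            torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"
        )
        n_img = t * h * w
        pos[0, s:s + n_img] = cur + ti.reshape(-1)           # temporal axis
        pos[1, s:s + n_img] = cur + hi.reshape(-1)           # height axis
        pos[2, s:s + n_img] = cur + wi.reshape(-1)           # width axis
        cur += max(t, h, w)                                  # text resumes after the widest axis
        idx = s + n_img
    pos[:, idx:] = torch.arange(cur, cur + (seq_len - idx))  # trailing text
    return pos

Under this scheme every image shifts all later positions by max(t, h, w) - token_num relative to plain 0..L-1 ids, which is exactly the per-image grid_thwd[3] delta that init_some_extra_state accumulates during decode.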

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 2 deletions
@@ -19,8 +19,6 @@ def __init__(self, layer_num, network_config, mode=[]):
         self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")

     def _get_qkv(self, input, infer_state, layer_weight):
-        if infer_state.rope_type != "mrope":
-            return super()._get_qkv(input, infer_state, layer_weight)
         q = layer_weight.q_proj.mm(input)
         cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         seq_len, _ = q.shape
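With the non-mrope fallback removed, _get_qkv always works with the three-axis cos/sin tables. Conceptually they are collapsed from (3, L, D) to (L, D) by picking one axis per rotary dimension, which is what the axis_map derived from mrope_section encodes. A hedged sketch of that selection step, under assumed axis_map semantics (this is not lightllm's actual rotary kernel):

import torch

def select_mrope_cos_sin(cos3, sin3, axis_map):
    # cos3 / sin3: (3, L, D) tables gathered with the (3, L) mrope position ids.
    # axis_map:    (D,) int tensor assigning every rotary dimension to axis 0/1/2
    #              (temporal / height / width), e.g. built from mrope_section.
    axis_map = axis_map.long()
    dim_idx = torch.arange(cos3.shape[-1], device=cos3.device)
    cos = cos3[axis_map, :, dim_idx].T   # choose an axis per dimension -> (L, D)
    sin = sin3[axis_map, :, dim_idx].T
    return cos, sin

The resulting (L, D) tables can then be applied by an ordinary rotary embedding routine.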

lightllm/models/qwen2_vl/model.py

Lines changed: 7 additions & 19 deletions
@@ -1,28 +1,15 @@
 import json
 import numpy as np
-import unicodedata
 from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
-from lightllm.models.qwen.model import QWenTpPartModel
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
 from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput
-from transformers.processing_utils import ProcessorMixin
 from lightllm.server.core.objs import SamplingParams
-from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from typing import List, Optional, Union
-from transformers.utils import TensorType, logging
-from lightllm.models.qwen2_vl.flashattention_infer_struct import Qwen2VLFlashAttentionStateInfo
 from lightllm.common.build_utils import repair_config
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
 from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer

-import torch
-from PIL import Image
 from .vision_process import smart_resize
-from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
-from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 import os

@@ -57,6 +44,9 @@ def get_image_token_length(self, img: ImageItem):
         )
         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
         token_num = (grid_h * grid_w) // (self.merge_size ** 2)
+        position_delta = max(grid_h // self.merge_size, grid_w // self.merge_size) - token_num
+        # the delta is prepared for mrope: it records the position_id offset introduced by the image tokens
+        img.grid_thwd = (1, grid_h // self.merge_size, grid_w // self.merge_size, position_delta)
         return token_num

     def get_audio_token_length(self, audio: AudioItem):
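A quick worked example of the new grid_thwd bookkeeping; the 448x448 resolution, patch_size of 14 and merge_size of 2 below are illustrative values, not read from any config:

grid_h = grid_w = 448 // 14                                  # 32 patches per side
token_num = (grid_h * grid_w) // (2 ** 2)                    # 256 merged image tokens
position_delta = max(grid_h // 2, grid_w // 2) - token_num   # 16 - 256 = -240
grid_thwd = (1, grid_h // 2, grid_w // 2, position_delta)    # (1, 16, 16, -240)

The negative delta is what the decode path in Qwen2VLInferStateInfo adds back to the flat position ids, so text after the image continues 16 positions past the image start instead of 256.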
@@ -71,26 +61,25 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         # <img></img> --> <img>id,id+1...id+num</img>
         input_ids = []
         image_id = 0
-        start_idx = 0
         while True:
             try:
-                start_idx = origin_ids.index(self.image_start_id, start_idx)
+                start_idx = origin_ids.index(self.image_start_id)
                 if start_idx + 1 >= len(origin_ids):
                     break
                 if origin_ids[start_idx + 1] == self.image_end_id:
                     input_ids.extend(origin_ids[: start_idx + 1])
                     token_id = multimodal_params.images[image_id].token_id
                     token_num = multimodal_params.images[image_id].token_num
+                    multimodal_params.images[image_id].start_idx = len(input_ids)
                     input_ids.extend(range(token_id, token_id + token_num))
                     input_ids.append(self.image_end_id)
                     origin_ids = origin_ids[start_idx + 2 :]
-                    start_idx = 0
                     image_id += 1
                 else:
                     raise ValueError("image token error")
             except ValueError:
                 break
-        input_ids.extend(origin_ids[start_idx:])
+        input_ids.extend(origin_ids)
         return input_ids


@@ -107,8 +96,7 @@ def __init__(self, kvargs):
         return

     def _init_inferstate_cls(self):
-        if get_env_start_args().enable_fa3:
-            self.infer_state_class = Qwen2VLFlashAttentionStateInfo
+        pass

     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
