Skip to content

Commit 9787303

Browse files
author
yihuiwen
committed
support tarsier2
1 parent 8bc96ba commit 9787303

File tree

9 files changed

+458
-2
lines changed

9 files changed

+458
-2
lines changed

lightllm/models/tarsier2/__init__.py

Whitespace-only changes.

lightllm/models/tarsier2/layer_weights/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
import numpy as np
3+
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
4+
5+
from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
6+
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
7+
8+
9+
# Tarsier2 checkpoints store the language model under a "language_model."
# prefix (e.g. "language_model.model.embed_tokens.weight").  Alias those
# entries to their unprefixed names so the plain text-model loaders can find
# them.  Only PreAndPostLayerWeight loading needs this; TransformLayerWeight
# keys are already correct.
def rename_weight_keys(weights):
    """Add an unprefixed alias for every key that starts with "language_model.".

    Mutates ``weights`` in place and returns None.  The original prefixed
    entries are kept, so loaders that expect the prefixed names still work.
    """
    prefix = "language_model."
    # Snapshot the keys: we insert new entries while iterating.
    for key in list(weights.keys()):
        # Must be a *leading* prefix: slicing off len(prefix) characters is
        # only correct when the prefix is at the start of the key.  The old
        # substring check (`prefix in key`) would mis-slice keys that merely
        # contain the prefix somewhere in the middle.
        if key.startswith(prefix):
            weights[key[len(prefix):]] = weights[key]
17+
18+
19+
class Tarsier2Qwen2PreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    """Qwen2 pre/post layer weights for tarsier2 checkpoints.

    Identical to the plain Qwen2 version except that checkpoint keys are
    normalized (the "language_model." prefix aliased away) before loading.
    """

    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)

    def load_hf_weights(self, weights):
        # Expose the wrapped language-model weights under plain key names,
        # then defer to the standard Qwen2 loader.
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
28+
29+
30+
class Tarsier2LlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    """Llama pre/post layer weights for tarsier2 checkpoints.

    Identical to the plain Llama version except that checkpoint keys are
    normalized (the "language_model." prefix aliased away) before loading.
    """

    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)

    def load_hf_weights(self, weights):
        # Expose the wrapped language-model weights under plain key names,
        # then defer to the standard Llama loader.
        rename_weight_keys(weights)
        super().load_hf_weights(weights)

lightllm/models/tarsier2/model.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import json
2+
import os
3+
4+
from lightllm.common.build_utils import repair_config
5+
from lightllm.models.llama.model import LlamaTpPartModel
6+
from lightllm.models.qwen2.model import Qwen2TpPartModel
7+
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
8+
from lightllm.models.qwen2_vl.vision_process import smart_resize
9+
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
10+
from lightllm.models.tarsier2.layer_weights.pre_and_post_layer_weight import Tarsier2Qwen2PreAndPostLayerWeight, Tarsier2LlamaPreAndPostLayerWeight
11+
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
12+
from lightllm.server.core.objs import SamplingParams
13+
14+
15+
class Tarsier2Tokenizer:
    """Tokenizer wrapper for tarsier2 multimodal prompts.

    Wraps a plain text tokenizer plus an image processor, and rewrites the
    encoded token stream so each image placeholder pair
    ``<vision_start><vision_end>`` is expanded into the per-image token-id
    range allocated by the serving layer.  All attributes other than
    ``encode`` are delegated to the wrapped tokenizer via ``__getattr__``.
    """

    def __init__(self, tokenizer=None, image_processor=None, **kwargs):
        # Underlying HF-style text tokenizer (used for the initial encode and
        # for all delegated attribute access).
        self.tokenizer = tokenizer
        # Image processor; only its patch_size / merge_size are read here.
        self.image_processor = image_processor
        # Special token ids come from the model's text_config section.
        self.image_start_id = kwargs["model_cfg"]["text_config"]["vision_start_token_id"]
        self.image_end_id = kwargs["model_cfg"]["text_config"]["vision_end_token_id"]
        self.image_token_id = kwargs["model_cfg"]["text_config"]["image_token_id"]

    def init_imageItem_extral_params(
        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
    ):
        # No per-image extra parameters are needed for tarsier2.
        return

    def get_image_token_length(self, img: ImageItem):
        """Return the number of placeholder tokens one image occupies.

        Mirrors the Qwen2-VL patch/merge arithmetic: the smart-resized image
        is split into patch_size patches, and merge_size**2 patches collapse
        into one token.  NOTE(review): this caches patch_size / merge_size /
        token_num / image_length on *self*, so the values reflect the most
        recently measured image — confirm callers never rely on them across
        images.
        """
        width = img.image_w
        height = img.image_h
        resized_height, resized_width = smart_resize(height=height, width=width)
        self.patch_size = self.image_processor.patch_size
        self.merge_size = self.image_processor.merge_size
        # grid_t is 1: single still image (no temporal dimension).
        grid_t = 1
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        merge_length = self.merge_size ** 2
        self.token_num = (grid_t * grid_h * grid_w) // merge_length
        self.image_length = self.token_num
        return self.image_length

    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
        """Encode *prompt* and splice in per-image token-id ranges.

        Each ``<img></img>`` pair (after dropping pad tokens) is expanded to
        ``<img> token_id .. token_id+token_num-1 </img>`` using the ids
        allocated in ``multimodal_params.images``, consumed in order.
        """
        origin_ids = self.tokenizer.encode(prompt)

        # <img><image_pad></img> -> <img></img>
        origin_ids = [token for token in origin_ids if token != self.image_token_id]
        # <img></img> --> <img>id,id+1...id+num</img>
        input_ids = []
        image_id = 0
        start_idx = 0
        while True:
            try:
                # list.index raises ValueError when no start token remains;
                # the surrounding except turns that into loop exit.
                start_idx = origin_ids.index(self.image_start_id, start_idx)
                if start_idx + 1 >= len(origin_ids):
                    # Start token is the last token: nothing to expand.
                    break
                if origin_ids[start_idx + 1] == self.image_end_id:
                    # Copy everything up to and including the start token.
                    input_ids.extend(origin_ids[: start_idx + 1])
                    token_id = multimodal_params.images[image_id].token_id
                    token_num = multimodal_params.images[image_id].token_num
                    input_ids.extend(range(token_id, token_id + token_num))
                    input_ids.append(self.image_end_id)
                    # Drop the consumed span and rescan from the beginning of
                    # the remaining tail.
                    origin_ids = origin_ids[start_idx + 2 :]
                    start_idx = 0
                    image_id += 1
                else:
                    # NOTE(review): this ValueError is caught by the except
                    # below, so a start token NOT followed by an end token
                    # silently ends the scan instead of raising — confirm
                    # this best-effort behavior is intended.
                    raise ValueError("image token error")
            except ValueError:
                break
        # Append whatever tail remains after the last processed image.
        input_ids.extend(origin_ids[start_idx:])
        return input_ids

    def __getattr__(self, name):
        # Called only for attributes not found normally; delegate to the
        # wrapped tokenizer.  (The "encode" guard is defensive: encode is
        # defined on this class, so __getattr__ never sees it in practice.)
        if name != "encode":
            return getattr(self.tokenizer, name)
        return self.encode
77+
78+
79+
pass
80+
81+
class Tarsier2Qwen2TpPartModel(Qwen2TpPartModel):
    """Tarsier2 text backbone on Qwen2: multimodal pre-layer inference plus
    prefix-aware pre/post weight loading; model config is taken from the
    "text_config" section of config.json."""

    # weight class
    pre_and_post_weight_class = Tarsier2Qwen2PreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _init_config(self):
        # Tarsier2 nests the language-model settings under "text_config".
        config_path = os.path.join(self.weight_dir_, "config.json")
        with open(config_path, "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # Normalize alternate key spellings used by some checkpoints.
        for same_names in (
            ["num_attention_heads", "n_head"],
            ["hidden_size", "n_embd", "n_embed"],
            ["num_hidden_layers", "n_layer"],
        ):
            repair_config(self.config, same_names=same_names)
100+
101+
102+
class Tarsier2Qwen2VLTpPartModel(Qwen2VLTpPartModel):
    """Tarsier2 text backbone on Qwen2-VL: multimodal pre-layer inference plus
    prefix-aware pre/post weight loading; model config is taken from the
    "text_config" section of config.json."""

    # weight class
    pre_and_post_weight_class = Tarsier2Qwen2PreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _init_config(self):
        # Tarsier2 nests the language-model settings under "text_config".
        config_path = os.path.join(self.weight_dir_, "config.json")
        with open(config_path, "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # Normalize alternate key spellings used by some checkpoints.
        for same_names in (
            ["num_attention_heads", "n_head"],
            ["hidden_size", "n_embd", "n_embed"],
            ["num_hidden_layers", "n_layer"],
        ):
            repair_config(self.config, same_names=same_names)
121+
122+
class Tarsier2LlamaTpPartModel(LlamaTpPartModel):
    """Tarsier2 text backbone on Llama: multimodal pre-layer inference plus
    prefix-aware pre/post weight loading; model config is taken from the
    "text_config" section of config.json."""

    # weight class
    pre_and_post_weight_class = Tarsier2LlamaPreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _init_config(self):
        # Tarsier2 nests the language-model settings under "text_config".
        config_path = os.path.join(self.weight_dir_, "config.json")
        with open(config_path, "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # Normalize alternate key spellings used by some checkpoints.
        for same_names in (
            ["num_attention_heads", "n_head"],
            ["hidden_size", "n_embd", "n_embed"],
            ["num_hidden_layers", "n_layer"],
        ):
            repair_config(self.config, same_names=same_names)

0 commit comments

Comments
 (0)