77from lightllm .server .core .objs import SamplingParams
88from lightllm .models .registry import ModelRegistry
99from lightllm .models .qwen2 .model import Qwen2TpPartModel
10- from lightllm .models .qwen2_vl .vision_process import smart_resize
10+ from lightllm .models .qwen_vl .layer_infer .pre_layer_infer import LlamaMultimodalPreLayerInfer
11+ from lightllm .models .internvl .layer_weights .pre_and_post_layer_weight import InternVLLlamaPreAndPostLayerWeight
12+ from lightllm .models .internvl .img_process import get_image_patch
1113
1214from ..mineru2_qwen .image_processing_mineru2 import Mineru2ImageProcessor
15+ from .image_processing_mineru2 import get_anyres_image_grid_shape
16+
# Delimiter tokens wrapped around each image's embedding-token range in the
# final token stream (see encode(): "<img> id..id+n </img>").
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
# Placeholder string the caller writes into the prompt once per image; it is
# rewritten to IMG_START_TOKEN + IMG_END_TOKEN before tokenization.
IMG_TOKEN = "<image>"
1320
1421
1522class Mineru2QwenTokenizer (BaseMultiModalTokenizer ):
1623 def __init__ (self , tokenizer , model_cfg ):
1724 super ().__init__ (tokenizer )
18- self .image_token = model_cfg .get ("image_token" , "<image>" )
19- # for llava-v1.5-7b-hf model
25+
26+ self .image_token = model_cfg .get ("image_token" , IMG_TOKEN )
27+ self .img_token_index = model_cfg .get ("image_token_index" , 151646 )
28+
29+ self .image_start_tag = IMG_START_TOKEN
30+ self .image_start_id = tokenizer .convert_tokens_to_ids (self .image_start_tag )
31+
32+ self .image_end_tag = IMG_END_TOKEN
33+ self .image_end_id = tokenizer .convert_tokens_to_ids (self .image_end_tag )
34+
2035 if "text_config" in model_cfg :
2136 patch_size = model_cfg ["vision_config" ]["patch_size" ]
2237 image_size = model_cfg ["vision_config" ]["image_size" ]
@@ -30,9 +45,12 @@ def __init__(self, tokenizer, model_cfg):
3045 default_img_size = int (vision_tower_match .group (3 ))
3146 image_size = model_cfg .get ("img_size" , default_img_size )
3247 image_size = model_cfg .get ("mm_image_size" , image_size )
33- # (image_size // patch_size) ** 2: (384 // 14) ** 2 = 729
48+
49+ self .image_processor = Mineru2ImageProcessor (
50+ image_aspect_ratio = getattr (model_cfg , "image_aspect_ratio" , None ),
51+ image_grid_pinpoints = getattr (model_cfg , "image_grid_pinpoints" , None ),
52+ )
3453 self .image_length = (image_size // patch_size ) ** 2
35- self .skip_start = model_cfg .get ("skip_start" , True )
3654
3755 def init_imageitem_extral_params (
3856 self , img : ImageItem , multi_params : MultimodalParams , sampling_params : SamplingParams
@@ -52,30 +70,47 @@ def get_audio_token_length(self, audio: AudioItem):
5270
5371 # only change the impl of the encode func:
5472 def encode (self , prompt , multimodal_params : MultimodalParams = None , add_special_tokens : bool = True ):
55- image_token_id = getattr (self , "image_token_index" , 151646 )
56- image_token = self .image_token
57-
58- text_parts = prompt .split (image_token )
59- token_ids = []
60- image_offsets = []
61- offset = 0
62- for i , part in enumerate (text_parts ):
63- part_ids = self .tokenizer .encode (part , add_special_tokens = (add_special_tokens if i == 0 else False ))
64- token_ids .extend (part_ids )
65- offset += len (part_ids )
66- if i < len (text_parts ) - 1 :
67- token_ids .append (image_token_id )
68- image_offsets .append (offset )
69- offset += 1
70-
71- # 记录image_offsets方便后处理
72- if multimodal_params is not None :
73- multimodal_params .image_offsets = image_offsets
74- # multimodal_params.image_pad_len 可在后处理时补充
75- return token_ids
73+ # TEXT<image>TEXT<image>TEXT --> TEXT<img></img>TEXT<img></img>TEXT
74+ image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
75+ if multimodal_params is None :
76+ return self .tokenizer .encode (prompt , add_special_tokens = add_special_tokens )
77+ image_count = len (multimodal_params .images )
78+ prompt = prompt .replace (IMG_TOKEN , image_tokens , image_count )
79+
80+ origin_ids = self .tokenizer .encode (prompt , add_special_tokens = add_special_tokens )
81+ # <img></img> --> <img>id,id+1...id+num</img>
82+ input_ids = []
83+ image_id = 0
84+ start_idx = 0
85+ while True :
86+ try :
87+ start_idx = origin_ids .index (self .image_start_id , start_idx )
88+ if start_idx + 1 >= len (origin_ids ):
89+ break
90+ if origin_ids [start_idx + 1 ] == self .image_end_id :
91+ input_ids .extend (origin_ids [: start_idx + 1 ])
92+ token_id = multimodal_params .images [image_id ].token_id
93+ token_num = multimodal_params .images [image_id ].token_num
94+ input_ids .extend (range (token_id , token_id + token_num ))
95+ input_ids .append (self .image_end_id )
96+ origin_ids = origin_ids [start_idx + 2 :]
97+ start_idx = 0
98+ image_id += 1
99+ else :
100+ raise ValueError ("image token error" )
101+ except ValueError :
102+ break
103+ input_ids .extend (origin_ids [start_idx :])
104+ return input_ids
76105
77106
78107@ModelRegistry ("mineru2_qwen" , is_multimodal = True )
79108class Mineru2QwenForCausalLM (Qwen2TpPartModel ):
109+ # weight class
110+ pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
111+
112+ # infer class
113+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
114+
80115 def __init__ (self , kvargs ):
81116 super ().__init__ (kvargs )
0 commit comments