@@ -44,7 +44,7 @@ def __init__(self,
4444 hf_config : AutoConfig = None ,
4545 backend : str = '' ):
4646 super ().__init__ (model_path , with_llm , max_memory , hf_config , backend )
47- self .arch = hf_config .architectures [0 ]
47+ self .arch = self . hf_config .architectures [0 ]
4848
def build_preprocessor(self):
    """Instantiate the HF processor used to preprocess multimodal inputs.

    Loads an ``AutoProcessor`` from ``self.model_path`` (allowing custom
    model code via ``trust_remote_code``) and stores it on the instance.
    """
    processor = AutoProcessor.from_pretrained(self.model_path,
                                              trust_remote_code=True)
    self.processor = processor
@@ -146,8 +146,32 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
146146 messages .append (dict (role = 'forward' , content = outputs ))
147147 return messages
148148
149- @staticmethod
def proc_internvl_hf_messages(self, content: List[Dict], IMAGE_TOKEN: str) -> str:
    """Flatten a 'user' message content list into one prompt string for
    InternVL HF models.

    Args:
        content: list of content items; each item is a dict whose 'type'
            is 'text' (with the text under the 'text' key), 'image' or
            'image_url'.
        IMAGE_TOKEN: placeholder inserted into the prompt for each image.

    Returns:
        The concatenated prompt; every image item contributes the token
        followed by a newline, text items contribute their text verbatim.

    Raises:
        ValueError: if an item carries an unsupported 'type'.
    """
    res = []
    for item in content:
        if item['type'] == 'text':
            res.append(item['text'])
        elif item['type'] in ['image', 'image_url']:
            # One image token per image, newline-terminated so text that
            # follows starts on its own line.
            res.append(f'{IMAGE_TOKEN}\n')
        else:
            raise ValueError(f'Unsupported message type: {item["type"]}')
    return ''.join(res)
160+
def proc_interns1_messages(self, content: List[Dict], IMAGE_TOKEN: str) -> str:
    """Flatten a 'user' message content list into one prompt string for
    InternS1 models.

    Args:
        content: list of content items; each item is a dict whose 'type'
            is 'text' (with the text under the 'text' key), 'image' or
            'image_url'.
        IMAGE_TOKEN: placeholder inserted into the prompt for each image.

    Returns:
        The items joined with newlines; each image item contributes the
        bare image token, text items contribute their text verbatim.

    Raises:
        ValueError: if an item carries an unsupported 'type'.
    """
    res = []
    for item in content:
        if item['type'] == 'text':
            res.append(item['text'])
        elif item['type'] in ['image', 'image_url']:
            res.append(IMAGE_TOKEN)
        else:
            raise ValueError(f'Unsupported message type: {item["type"]}')
    # Unlike the InternVL HF variant, pieces are newline-joined rather
    # than concatenated.
    return '\n'.join(res)
172+
150173 def proc_messages (
174+ self ,
151175 messages ,
152176 chat_template ,
153177 sequence_start ,
@@ -158,24 +182,17 @@ def proc_messages(
158182 prompt_messages = []
159183 IMAGE_TOKEN = '<IMAGE_TOKEN>'
160184 for message in messages :
161- if isinstance (message ['content' ], str ):
162- prompt_messages .append (message )
185+ if message ['role' ] in ['preprocess' , 'forward' ]:
163186 continue
164- elif message ['role' ] in ['preprocess' , 'forward' ]:
165- continue
166- n_images = len ([1 for x in message ['content' ] if x ['type' ] == 'image' ])
167- content = [x .get ('text' , '' ) for x in message ['content' ] if x ['type' ] == 'text' ]
168- prompt = content [0 ]
169- if IMAGE_TOKEN in prompt and f'<img>{ IMAGE_TOKEN } ' not in prompt :
170- prompt = prompt .replace (f'{ IMAGE_TOKEN } ' , f'<img>{ IMAGE_TOKEN } </img>' )
171- prompt = prompt .replace ('</img><img>' , '' )
172- prompt = prompt .replace ('<img><img>' , '<img>' )
173- prompt = prompt .replace ('</img></img>' , '</img>' )
174- elif IMAGE_TOKEN not in prompt :
175- prompt = f'<img>{ IMAGE_TOKEN * n_images } </img>\n ' + prompt
187+ role , content = message ['role' ], message ['content' ]
188+ if role == 'user' and isinstance (content , List ):
189+ content = (self .proc_internvl_hf_messages (content , IMAGE_TOKEN ) if self .arch
190+ == 'InternVLForConditionalGeneration' else self .proc_interns1_messages (content , IMAGE_TOKEN ))
191+ message = dict (role = role , content = content )
192+ prompt_messages .append (message )
176193 else :
177- pass
178- prompt_messages . append ( dict ( role = 'user' , content = prompt ))
194+ prompt_messages . append ( message )
195+
179196 prompt = chat_template .messages2prompt (prompt_messages ,
180197 sequence_start ,
181198 tools = tools ,
0 commit comments