@@ -1,14 +1,33 @@
-import os
-from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union, Iterable, Sequence, Dict, Any
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-import yaml
 import json
+import math
+import os
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Sequence
+
 import paddle
-from PIL import Image
+import yaml
 from paddlenlp.data import DataCollatorForSeq2Seq
+from paddlenlp.transformers.processing_utils import ProcessorMixin
 from paddlenlp.utils.import_utils import import_module
-from paddlemix.processors.qwen2_5_vl_processing import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor,process_vision_info
+from PIL import Image
+
+from paddlemix.models.qwen2_vl.template import TEMPLATES
+from paddlemix.processors.qwen2_5_vl_processing import process_vision_info
 
 
 @dataclass
@@ -17,11 +36,12 @@ class Qwen2VLDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
     process_vision_info = None
     template_name: Optional[str] = None
+
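+    # Lazily resolve the vision-info helper from the template-specific processing
+    # module (e.g. paddlemix.processors.qwen2_5_vl_processing).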
     def __post_init__(self):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
         if self.process_vision_info is None:
-            self.process_vision_info = import_module(f'paddlemix.processors.{self.template_name}_processing')
+            self.process_vision_info = import_module(f"paddlemix.processors.{self.template_name}_processing")
 
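+    # Collate {prompt, solution, image} features into one batch: each prompt is
+    # rendered with the chat template, and the tokenized solution text becomes "labels".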
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
         batched_pixel_values = []
@@ -30,12 +50,12 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
         batched_labels = []
         batched_image_grid_thw = []
         for feature in features:
-            messages = feature['prompt']
-            solution_text = feature['solution']
+            messages = feature["prompt"]
+            solution_text = feature["solution"]
             prompt_text = self.processor.tokenizer.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
-            messages[0]['content'][0]['image'] = feature['image']
+            messages[0]["content"][0]["image"] = feature["image"]
             image_inputs, video_inputs = process_vision_info(messages)
             inputs = self.processor(
                 text=prompt_text,
@@ -52,11 +72,11 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
                 padding_side="left",
                 add_special_tokens=False,
             )
-            batched_pixel_values.append(inputs['pixel_values'])
-            batched_input_ids.append(inputs['input_ids'][0])
-            batched_attention_mask.append(inputs['attention_mask'][0])
-            batched_labels.append(solution_inputs['input_ids'][0])
-            batched_image_grid_thw.append(inputs['image_grid_thw'])
+            batched_pixel_values.append(inputs["pixel_values"])
+            batched_input_ids.append(inputs["input_ids"][0])
+            batched_attention_mask.append(inputs["attention_mask"][0])
+            batched_labels.append(solution_inputs["input_ids"][0])
+            batched_image_grid_thw.append(inputs["image_grid_thw"])
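+        # Note: paddle.stack assumes every sample in the batch shares one image
+        # resolution (identical pixel_values / image_grid_thw shapes).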
         return {
             "pixel_values": paddle.stack(batched_pixel_values),
             "attention_mask": paddle.stack(batched_attention_mask),
@@ -67,15 +87,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
 
 
 class Qwen2VLRECDataset(paddle.io.Dataset):
-    def __init__(self,
-        data_path: str,
-        script_args,
-        training_args,
-        model_args,
-        tokenizer,
-        processor,
-        template
-    ):
+    def __init__(self, data_path: str, script_args, training_args, model_args, tokenizer, processor, template):
         super(Qwen2VLRECDataset, self).__init__()
         self.script_args = script_args
         self.list_data_dict = []
@@ -105,15 +117,9 @@ def __init__(self,
                     else:
                         raise ValueError(f"Unsupported file type: {json_path}")
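+                    # sampling_strategy has the form "<name>:<count>" or "<name>:<percent>%",
+                    # e.g. "first:100" keeps the first 100 samples of this sub-dataset.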
107119 if ":" in sampling_strategy :
108- sampling_strategy , sampling_number = sampling_strategy .split (
109- ":"
110- )
120+ sampling_strategy , sampling_number = sampling_strategy .split (":" )
111121 if "%" in sampling_number :
112- sampling_number = math .ceil (
113- int (sampling_number .split ("%" )[0 ])
114- * len (cur_data_dict )
115- / 100
116- )
122+ sampling_number = math .ceil (int (sampling_number .split ("%" )[0 ]) * len (cur_data_dict ) / 100 )
117123 else :
118124 sampling_number = int (sampling_number )
119125 if sampling_strategy == "first" and sampling_number is not None :
@@ -131,11 +137,13 @@ def __init__(self,
     def __len__(self):
         return len(self.list_data_dict)
 
-    def _preprocess_image(self, image,image_max_pixels,image_min_pixels):
+    def _preprocess_image(self, image, image_max_pixels, image_min_pixels):
         r"""
         Pre-processes a single image.
         """
-        image = self.template.mm_plugin._preprocess_image(image,image_max_pixels=image_max_pixels,image_min_pixels=image_min_pixels)
+        image = self.template.mm_plugin._preprocess_image(
+            image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+        )
         return image
 
     def get_image_path(self, image_path):
@@ -146,9 +154,7 @@ def get_transform(self):
 
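+    # Tokenize the chat-rendered prompt together with its image, and tokenize the
+    # label text separately so it can be attached as "labels".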
     def multi_modal_get_item(self, data_item):
         messages = data_item["messages"]
-        text = self.processor.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(messages)
         inputs = self.processor(
             text=text,
@@ -158,38 +164,38 @@ def multi_modal_get_item(self, data_item):
             return_tensors="pd",
         )
         label_ids = self.processor.tokenizer(
-            text=str(data_item['label']),
+            text=str(data_item["label"]),
             padding=True,
             padding_side="left",
             return_tensors="pd",
         )
-        #unwrap
-        inputs['input_ids'] = inputs['input_ids'][0]
-        inputs['attention_mask'] = inputs['attention_mask'][0]
+        # unwrap
+        inputs["input_ids"] = inputs["input_ids"][0]
+        inputs["attention_mask"] = inputs["attention_mask"][0]
 
         # Create the final return dictionary
         ret = dict(
             **inputs,
-            labels=label_ids['input_ids'][0],
+            labels=label_ids["input_ids"][0],
         )
         return ret
 
     def __getitem__(self, i):
         QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
 
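+        # Wrap the raw sample into a single-turn user message carrying the image
+        # and the templated question text.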
-        def make_conversation_image(example,image):
+        def make_conversation_image(example, image):
             return {
                 "messages": [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image", "image": image},
-                        {
-                            "type": "text",
-                            "text": QUESTION_TEMPLATE.format(Question=example['problem']),
-                        },
-                    ],
-                }
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image},
+                            {
+                                "type": "text",
+                                "text": QUESTION_TEMPLATE.format(Question=example["problem"]),
+                            },
+                        ],
+                    }
                 ]
             }
 
@@ -198,22 +204,21 @@ def make_conversation_image(example,image):
198204 if "image" in example :
199205 image_path = os .path .join (image_root , example ["image" ])
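+            # If the file is missing, resample a random item until an existing
+            # image path is found.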
             while not os.path.exists(image_path):
-                print(
-                    f"Warning: Image {image_path} not found, randomly selecting another image"
-                )
+                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                 new_index = random.randint(0, len(self.list_data_dict) - 1)
                 example = self.list_data_dict[new_index]
                 image_path = os.path.join(image_root, example["image"])
-            image = self._preprocess_image(Image.open(image_path).convert("RGB"),
-                image_max_pixels=self.script_args.max_pixels,
-                image_min_pixels=self.script_args.min_pixels,
-            )
+            image = self._preprocess_image(
+                Image.open(image_path).convert("RGB"),
+                image_max_pixels=self.script_args.max_pixels,
+                image_min_pixels=self.script_args.min_pixels,
+            )
         else:
             image = None
-        data_item = {
+        data_item = {
             "image": image,
-            "image_path": example['image'],
+            "image_path": example["image"],
             "label": example["solution"],
-            "messages": make_conversation_image(example,image)["messages"]
+            "messages": make_conversation_image(example, image)["messages"],
         }
-        return self.multi_modal_get_item(data_item)
\ No newline at end of file
+        return self.multi_modal_get_item(data_item)