Commit 469f27e

refine r1_mllm codes and shells (#1201)
1 parent e88d893 commit 469f27e

File tree: 9 files changed (+473, -429 lines)

paddlemix/examples/r1_mllm/README.md

Lines changed: 32 additions & 3 deletions
@@ -20,14 +20,35 @@
 
 
 ## Data Preparation
+
 ### Referring Expression Comprehension (REC) Task
+
+* Download the dataset prepared by the PaddleMIX team:
+```bash
+https://paddlenlp.bj.bcebos.com/datasets/paddlemix/playground/r1_mllm/REC.tar
+```
+
+Or download the original datasets individually:
+
 * Download the [COCO Train2014 images](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/train2014.zip) and extract them to the data/coco directory under PaddleMIX.
+```
+wget https://paddlenlp.bj.bcebos.com/datasets/paddlemix/refcoco/train2014.tar
+```
 
 * Download [RefGTA](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/refgta.zip) and extract it to the data/refgta directory.
 
 * Download the [RefCOCO/+/g and RefGTA Annotation files](https://huggingface.co/datasets/omlab/VLM-R1/resolve/main/rec_jsons_processed.zip) and extract them to the PaddleMIX/data/rec_jsons_processed directory (RefGTA is out-of-domain test data, used to evaluate generalization).
 
+
 ### Counting Task
+
+* Download the dataset prepared by the PaddleMIX team:
+```bash
+https://paddlenlp.bj.bcebos.com/datasets/paddlemix/playground/r1_mllm/Counting.tar
+```
+
+Or download the original datasets individually:
+
 * Download the [CLEVR-70K-Counting](https://huggingface.co/datasets/leonardPKU/clevr_cogen_a_train) training dataset, replacing your_path with your actual installation path, e.g. data/clevr_cogen_a_train.
 ```bash
 huggingface-cli download --resume-download leonardPKU/clevr_cogen_a_train --local-dir data/clevr_cogen_a_train --repo-type="dataset"
@@ -39,6 +60,14 @@ huggingface-cli download --resume-download leonardPKU/clevr_cogen_a_train --loca
 
 
 ### Geometry Reasoning Task
+
+* Download the dataset prepared by the PaddleMIX team:
+```bash
+https://paddlenlp.bj.bcebos.com/datasets/paddlemix/playground/r1_mllm/GEO.tar
+```
+
+Or download the original datasets individually:
+
 * Download [GEOQA-8k](https://huggingface.co/datasets/leonardPKU/GEOQA_R1V_Train_8K) to the data/GEOQA_R1V_Train_8K directory.
 ```bash
 huggingface-cli download --resume-download leonardPKU/GEOQA_R1V_Train_8K --local-dir data/GEOQA_R1V_Train_8K --repo-type="dataset"
@@ -57,7 +86,7 @@ unzip data/Geo170K/images.zip -d data/Geo170K
 ### Performance Metrics
 With a fixed random seed, 500 samples are drawn from the validation set for evaluation; the results are as follows:
 
-| Model                               | refcoco val | refcoco+ val | refcocog val | RefGTA |
+| Model                               | refcoco val | refcoco+ val | refcocog val | RefGTA |
 |-------------------------------------|-------------|--------------|--------------|--------|
 | Qwen2.5-VL-3B-Instruct              | 88.60%      | 79.60%       | 81.80%       | 71.80% |
 | R1-Qwen2.5-VL-3B-Instruct(500steps) | 88.40%      | 83.60%       | 81.80%       | 74.60% |
@@ -140,7 +169,7 @@ python paddlemix/examples/r1_mllm/eval/test_r1-v.py \
     --steps 500 \
     --seed 42
 
-# test r1 geoqa
+# test r1 geoqa
 python paddlemix/examples/r1_mllm/eval/test_r1-v.py \
     --model_name "Qwen2.5-VL-3B-Instruct" \
     --method "r1" \
@@ -171,4 +200,4 @@ python paddlemix/examples/r1_mllm/eval/test_r1-v.py \
     note = {Accessed: 2025-02-02},
     year = {2025}
 }
-```
+```
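As a side note on the data layout: the prepared archives above are plain tarballs, so unpacking them under data/ reproduces the paths the updated configs expect. A minimal sketch, assuming each archive unpacks into a top-level folder named after itself (REC, Counting, GEO); the helper name and layout are illustrative, not part of this commit:

```python
import tarfile
from pathlib import Path

def extract_dataset(tar_path: str, data_root: str = "data") -> None:
    """Unpack a downloaded archive (e.g. REC.tar) under data_root."""
    Path(data_root).mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path) as tar:
        # Assumed layout: data/REC/..., data/Counting/..., data/GEO/...
        tar.extractall(path=data_root)

for archive in ("REC.tar", "Counting.tar", "GEO.tar"):
    extract_dataset(archive)
```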
Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
 datasets:
-  - json_path: data/rec_jsons_processed/refcoco_train.json
-  - json_path: data/rec_jsons_processed/refcocop_train.json
-  - json_path: data/rec_jsons_processed/refcocog_train.json
+  - json_path: data/REC/rec_jsons_processed/refcoco_train.json
+  - json_path: data/REC/rec_jsons_processed/refcocop_train.json
+  - json_path: data/REC/rec_jsons_processed/refcocog_train.json
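This YAML is the dataset list consumed at training time; the path change keeps it in sync with the new data/REC layout from the README. A minimal sketch of how such a `datasets` list might be read (schema taken from the config above; the loader function itself is illustrative, not the repo's actual loader):

```python
import json
import yaml

def load_dataset_entries(config_path: str) -> list:
    """Read the YAML config and concatenate the samples from every json_path."""
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    records = []
    for entry in config["datasets"]:
        with open(entry["json_path"], "r") as jf:
            # Each file is assumed to hold a JSON list of samples.
            records.extend(json.load(jf))
    return records
```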
Lines changed: 69 additions & 64 deletions
@@ -1,14 +1,33 @@
-import os
-from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union, Iterable, Sequence, Dict, Any
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-import yaml
 import json
+import math
+import os
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Sequence
+
 import paddle
-from PIL import Image
+import yaml
 from paddlenlp.data import DataCollatorForSeq2Seq
+from paddlenlp.transformers.processing_utils import ProcessorMixin
 from paddlenlp.utils.import_utils import import_module
-from paddlemix.processors.qwen2_5_vl_processing import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor,process_vision_info
+from PIL import Image
+
+from paddlemix.models.qwen2_vl.template import TEMPLATES
+from paddlemix.processors.qwen2_5_vl_processing import process_vision_info
 
 
 @dataclass
@@ -17,11 +36,12 @@ class Qwen2VLDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
     process_vision_info = None
     template_name: Optional[str] = None
+
     def __post_init__(self):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
         if self.process_vision_info is None:
-            self.process_vision_info = import_module(f'paddlemix.processors.{self.template_name}_processing')
+            self.process_vision_info = import_module(f"paddlemix.processors.{self.template_name}_processing")
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
         batched_pixel_values = []
@@ -30,12 +50,12 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
         batched_labels = []
         batched_image_grid_thw = []
         for feature in features:
-            messages = feature['prompt']
-            solution_text = feature['solution']
+            messages = feature["prompt"]
+            solution_text = feature["solution"]
             prompt_text = self.processor.tokenizer.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
-            messages[0]['content'][0]['image'] = feature['image']
+            messages[0]["content"][0]["image"] = feature["image"]
             image_inputs, video_inputs = process_vision_info(messages)
             inputs = self.processor(
                 text=prompt_text,
@@ -52,11 +72,11 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
                 padding_side="left",
                 add_special_tokens=False,
             )
-            batched_pixel_values.append(inputs['pixel_values'])
-            batched_input_ids.append(inputs['input_ids'][0])
-            batched_attention_mask.append(inputs['attention_mask'][0])
-            batched_labels.append(solution_inputs['input_ids'][0])
-            batched_image_grid_thw.append(inputs['image_grid_thw'])
+            batched_pixel_values.append(inputs["pixel_values"])
+            batched_input_ids.append(inputs["input_ids"][0])
+            batched_attention_mask.append(inputs["attention_mask"][0])
+            batched_labels.append(solution_inputs["input_ids"][0])
+            batched_image_grid_thw.append(inputs["image_grid_thw"])
         return {
             "pixel_values": paddle.stack(batched_pixel_values),
             "attention_mask": paddle.stack(batched_attention_mask),
@@ -67,15 +87,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
 
 
 class Qwen2VLRECDataset(paddle.io.Dataset):
-    def __init__(self,
-        data_path: str,
-        script_args,
-        training_args,
-        model_args,
-        tokenizer,
-        processor,
-        template
-    ):
+    def __init__(self, data_path: str, script_args, training_args, model_args, tokenizer, processor, template):
         super(Qwen2VLRECDataset, self).__init__()
         self.script_args = script_args
         self.list_data_dict = []
@@ -105,15 +117,9 @@ def __init__(self,
             else:
                 raise ValueError(f"Unsupported file type: {json_path}")
             if ":" in sampling_strategy:
-                sampling_strategy, sampling_number = sampling_strategy.split(
-                    ":"
-                )
+                sampling_strategy, sampling_number = sampling_strategy.split(":")
                 if "%" in sampling_number:
-                    sampling_number = math.ceil(
-                        int(sampling_number.split("%")[0])
-                        * len(cur_data_dict)
-                        / 100
-                    )
+                    sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                 else:
                     sampling_number = int(sampling_number)
             if sampling_strategy == "first" and sampling_number is not None:
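The sampling block above parses strategies of the form `name:count`, where the count is either absolute or a percentage of the dataset. A standalone restatement of that parsing (only the `first` strategy is visible in this hunk; any other strategy name would be an assumption):

```python
import math

def parse_sampling(sampling_strategy: str, dataset_len: int):
    """Split "first:100" or "first:10%" into a strategy name and a sample count."""
    sampling_number = None
    if ":" in sampling_strategy:
        sampling_strategy, sampling_number = sampling_strategy.split(":")
        if "%" in sampling_number:
            # Percentage of the dataset, rounded up.
            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * dataset_len / 100)
        else:
            sampling_number = int(sampling_number)
    return sampling_strategy, sampling_number

print(parse_sampling("first:10%", 250))  # ('first', 25)
```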
@@ -131,11 +137,13 @@ def __init__(self,
     def __len__(self):
         return len(self.list_data_dict)
 
-    def _preprocess_image(self, image,image_max_pixels,image_min_pixels):
+    def _preprocess_image(self, image, image_max_pixels, image_min_pixels):
         r"""
         Pre-processes a single image.
         """
-        image = self.template.mm_plugin._preprocess_image(image,image_max_pixels=image_max_pixels,image_min_pixels=image_min_pixels)
+        image = self.template.mm_plugin._preprocess_image(
+            image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+        )
         return image
 
     def get_image_path(self, image_path):
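`image_max_pixels` and `image_min_pixels` bound the image's pixel budget before it is turned into vision tokens. The exact rounding lives in the template's `mm_plugin`, which this diff does not show; the sketch below is only an approximation of that kind of rescaling, not the plugin's actual code:

```python
import math
from PIL import Image

def rescale_to_pixel_budget(image: Image.Image, max_pixels: int, min_pixels: int) -> Image.Image:
    """Scale the image so that width * height falls inside [min_pixels, max_pixels]."""
    width, height = image.size
    pixels = width * height
    if pixels > max_pixels:
        scale = math.sqrt(max_pixels / pixels)  # shrink to fit the budget
    elif pixels < min_pixels:
        scale = math.sqrt(min_pixels / pixels)  # grow to reach the minimum
    else:
        return image
    return image.resize((max(1, int(width * scale)), max(1, int(height * scale))))
```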
@@ -146,9 +154,7 @@ def get_transform(self):
 
     def multi_modal_get_item(self, data_item):
         messages = data_item["messages"]
-        text = self.processor.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(messages)
         inputs = self.processor(
             text=text,
@@ -158,38 +164,38 @@ def multi_modal_get_item(self, data_item):
             return_tensors="pd",
         )
         label_ids = self.processor.tokenizer(
-            text=str(data_item['label']),
+            text=str(data_item["label"]),
             padding=True,
             padding_side="left",
             return_tensors="pd",
         )
-        # unwrap
-        inputs['input_ids'] = inputs['input_ids'][0]
-        inputs['attention_mask'] = inputs['attention_mask'][0]
+        # unwrap
+        inputs["input_ids"] = inputs["input_ids"][0]
+        inputs["attention_mask"] = inputs["attention_mask"][0]
 
         # Create the final return dictionary
         ret = dict(
             **inputs,
-            labels=label_ids['input_ids'][0],
+            labels=label_ids["input_ids"][0],
         )
         return ret
 
     def __getitem__(self, i):
         QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
 
-        def make_conversation_image(example,image):
+        def make_conversation_image(example, image):
             return {
                 "messages": [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image", "image": image},
-                        {
-                            "type": "text",
-                            "text": QUESTION_TEMPLATE.format(Question=example['problem']),
-                        },
-                    ],
-                }
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "image", "image": image},
+                            {
+                                "type": "text",
+                                "text": QUESTION_TEMPLATE.format(Question=example["problem"]),
+                            },
+                        ],
+                    }
                 ]
             }
 
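For reference, the prompt that a single sample produces after this refactor is a Qwen2.5-VL style message list with one image slot and the templated question. The values below are placeholders, not data from the repo:

```python
QUESTION_TEMPLATE = (
    "{Question} First output the thinking process in <think> </think> tags "
    "and then output the final answer in <answer> </answer> tags. "
    "Output the final answer in JSON format."
)

example = {"problem": "Where is the red mug?"}  # placeholder sample
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "data/coco/train2014/example.jpg"},  # placeholder path
            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
        ],
    }
]
```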

@@ -198,22 +204,21 @@ def make_conversation_image(example,image):
         if "image" in example:
             image_path = os.path.join(image_root, example["image"])
             while not os.path.exists(image_path):
-                print(
-                    f"Warning: Image {image_path} not found, randomly selecting another image"
-                )
+                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                 new_index = random.randint(0, len(self.list_data_dict) - 1)
                 example = self.list_data_dict[new_index]
                 image_path = os.path.join(image_root, example["image"])
-            image = self._preprocess_image(Image.open(image_path).convert("RGB"),
-                image_max_pixels=self.script_args.max_pixels,
-                image_min_pixels=self.script_args.min_pixels,
-            )
+            image = self._preprocess_image(
+                Image.open(image_path).convert("RGB"),
+                image_max_pixels=self.script_args.max_pixels,
+                image_min_pixels=self.script_args.min_pixels,
+            )
         else:
             image = None
-        data_item = {
+        data_item = {
             "image": image,
-            "image_path": example['image'],
+            "image_path": example["image"],
             "label": example["solution"],
-            "messages": make_conversation_image(example,image)["messages"]
+            "messages": make_conversation_image(example, image)["messages"],
         }
-        return self.multi_modal_get_item(data_item)
+        return self.multi_modal_get_item(data_item)
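The `while not os.path.exists(...)` loop in this hunk resamples a random index whenever an image file is missing on disk. A self-contained sketch of that retry pattern (the sample list is invented; a retry cap is added here so the illustration cannot spin forever, which the unbounded loop in the diff can if no image exists at all):

```python
import os
import random

image_root = "data/coco"  # invented root for the illustration
samples = [{"image": "train2014/a.jpg"}, {"image": "train2014/b.jpg"}]

example = samples[0]
image_path = os.path.join(image_root, example["image"])
for _ in range(10):  # retry cap, absent in the original loop
    if os.path.exists(image_path):
        break
    print(f"Warning: Image {image_path} not found, randomly selecting another image")
    example = samples[random.randint(0, len(samples) - 1)]
    image_path = os.path.join(image_root, example["image"])
```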
