
Commit 0d175e0

Merge pull request #34 from fdcp/sample_packing_example
offline_sample_packing description
2 parents 5654460 + 1f540a3 commit 0d175e0

18 files changed, +4330 −0 lines

tools/data_preprocess/offline_packing_examples/README.md

Lines changed: 358 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
# Data path configuration
data:
  # Directory of data samples
  directory: "/data_1/llava_next_raw_full/split_json_files/"
  # directory: "/vlm/data/llava_next_500/split_json_files"
  # Temporary file storing the paired file names
  output_base: "base_name_v4_MR_sft_780k_8k.txt"
  # Final output file (includes token-length information)
  output_token: "token_info_MR_sft_780k_8k.txt"

# Model path
model:
  checkpoint: "/vlm/xiangan/pretrain_models/rice_vl/rice_vl_rice_300m_qwen2.5_7b_adapter_v1_fixed_tokenizer_huggingface"

sample:
  # Maximum length of a training sample
  max_len: 8192
  del_one_token: false
  # Determines how samples are parsed
  task_type: sft
  max_prompt: null
  max_answer: null

# Image processing parameters
image:
  baidu_resolution: 1800  # limit parameter from the baidu code (null)
  min_pixels: 3136        # 4*28*28
  max_pixels: 4014080     # 5120*28*28 (4014080, 8192)
  # Maximum aspect-ratio limit (images exceeding it are filtered out);
  # this is implicit information (Qwen-VL handles it itself)
  max_aspect_ratio: 200

# Parallel processing parameters
processing:
  # Chunk size handled by each process
  chunk_size: 5000
  # Merge parameter (sorting): every N stage0 files are merged into 1 stage1 file
  stage1_merge_chunk: 20
  n_workers: 64
  # Minimum number of threads in the thread pool
  min_workers: 10
  # Maximum number of threads in the thread pool
  max_workers: 32
  # Timeout (scale with data volume; estimate ~45 minutes (2700 s) per 1M samples)
  time_out: 20000

# Logging and temporary files
logging:
  # Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
  level: "INFO"
  # Log file path
  file: "./logs/s1_processing_MR_sft_780k_8k.log"
  # Whether to use /dev/shm as the temporary directory
  use_shm: false
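
For reference, a minimal sketch of how this config could be consumed (illustrative only, not part of the commit; assumes PyYAML is installed, and the file name s1_config.yaml is hypothetical):

import yaml  # assumes PyYAML is available

with open("s1_config.yaml", "r", encoding="utf-8") as f:  # hypothetical file name
    cfg = yaml.safe_load(f)

# Nested keys mirror the sections above
assert cfg["sample"]["max_len"] == 8192
assert cfg["processing"]["chunk_size"] == 5000
print(cfg["data"]["directory"])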
Lines changed: 213 additions & 0 deletions

@@ -0,0 +1,213 @@
import json
import os
import shutil
import logging
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Manager, cpu_count, Process
from queue import Empty
from tqdm import tqdm

# 1) __img--output files are numbered independently

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger()

# ---------- Utilities ----------
def extract_filename_without_ext(image_path: str) -> str:
    return os.path.splitext(os.path.basename(image_path))[0]


# ---------------------------- patch 1 ----------------------------
# New: thread-safe counter for duplicate file names
def _unique_filename(name: str, name_counter, name_lock) -> str:
    base, ext = os.path.splitext(name)
    with name_lock:
        # use get to avoid KeyError
        cnt = name_counter.get(name, 0)
        name_counter[name] = cnt + 1
    if cnt == 0:
        return name
    return f"{base}_{cnt}{ext}"
# -----------------------------------------------------------------


# ---------- Per-item processing ----------
def _process_single_item(args):
    """
    Thread level: process a single item.
    Arguments are packed into a tuple for ThreadPoolExecutor.
    """
    (item, base_dir, output_dir, rel_img_path, no_img_indices,
     name_counter, name_lock) = args  # patch 6

    # ---------- Normalize the original image paths ----------
    original_image_paths = []
    if item.get("images"):
        original_image_paths = item["images"] if isinstance(item["images"], list) else [item["images"]]
    else:
        item["images"] = []

    if rel_img_path:
        original_image_paths = [
            os.path.normpath(os.path.join(base_dir, rel_img_path, p))
            for p in original_image_paths
        ]
    else:
        original_image_paths = [
            os.path.normpath(os.path.join(base_dir, p))
            for p in original_image_paths
        ]

    # ---------- Rename consistently and copy the images ----------
    new_image_basenames = []
    for src_path in original_image_paths:
        if not os.path.exists(src_path):
            logger.warning(f"Image does not exist: {src_path}")
            continue
        old_name = os.path.basename(src_path)
        new_name = _unique_filename(old_name, name_counter, name_lock)  # may rename
        new_image_basenames.append(new_name)

        dst_path = os.path.join(output_dir, new_name)
        try:
            shutil.copy2(src_path, dst_path)
        except Exception as e:
            logger.error(f"Failed to copy image: {src_path} -> {dst_path} | {e}")

    # Keep the images field of the JSON in sync
    item["images"] = new_image_basenames

    # -------------- patch 001 ----------
    # New: if none of the images exist, return None right away
    if original_image_paths and not new_image_basenames:
        logger.info(f"Skipping element with no valid images: {item.get('id', item['_orig_index'])}")
        return None
    # -------------- patch 001 end ----------

    # ---------- Generate the JSON file name ----------
    if new_image_basenames:
        json_name_root = os.path.splitext(new_image_basenames[0])[0]
    else:
        idx_in_no_img = no_img_indices.index(item['_orig_index'])
        json_name_root = f"__img--output_{idx_in_no_img:08d}"

    json_name = _unique_filename(json_name_root + ".json", name_counter, name_lock)
    json_path = os.path.join(output_dir, json_name)
    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(item, f, indent=4, ensure_ascii=False)
    except Exception as e:
        logger.error(f"Failed to write JSON: {json_path} | {e}")

    return os.path.splitext(json_name)[0]

# ---------- Process level ----------
def _worker_process(job_queue, result_list, base_dir, output_dir,
                    rel_img_path, m, no_img_indices,
                    name_counter, name_lock):  # <-- patch 4
    while True:
        try:
            chunk = job_queue.get_nowait()
        except Empty:  # queue drained; avoid a bare except
            break

        logger.info(f"Process {os.getpid()} handling chunk ({len(chunk)} items)")
        # Build the argument list
        arg_list = [(item, base_dir, output_dir, rel_img_path, no_img_indices, name_counter, name_lock)
                    for item in chunk]

        valid_names = []
        with ThreadPoolExecutor(max_workers=m) as pool:
            for name in tqdm(pool.map(_process_single_item, arg_list),
                             total=len(arg_list),
                             desc=f"PID-{os.getpid()}",
                             leave=False):
                if name is not None:  # filter out None (patch 002)
                    valid_names.append(name)
        result_list.extend(valid_names)

# ---------- Main entry ----------
def split_json_file(fin_name, rel_img_path=None, *, chunk_dim=1000, m=8):
    # Load the data
    try:
        with open(fin_name, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"Failed to read JSON: {e}")
        return set()

    if not isinstance(data, list):
        logger.error("JSON root node is not an array")
        return set()

    # Tag original indices & collect the indices of items without images
    for i, item in enumerate(data):
        item['_orig_index'] = i
    no_img_indices = [i for i, item in enumerate(data) if not item.get("images")]

    # Prepare directories
    base_dir = os.path.dirname(os.path.abspath(fin_name))
    output_dir = os.path.join(base_dir, "split_json_files")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Split into chunks
    total = len(data)
    num_chunks = (total + chunk_dim - 1) // chunk_dim
    chunks = [data[i * chunk_dim:(i + 1) * chunk_dim] for i in range(num_chunks)]

    max_workers = min(num_chunks, cpu_count())
    logger.info(f"{total} items in total, split into {num_chunks} chunks; "
                f"starting {max_workers} processes with {m} threads each")

    with Manager() as manager:
        job_queue = manager.Queue()
        for c in chunks:
            job_queue.put(c)

        result_list = manager.list()
        name_counter = manager.dict()  # <-- new in patch 2
        name_lock = manager.Lock()     # <-- new in patch 2

        processes = [
            Process(target=_worker_process,
                    args=(job_queue, result_list, base_dir,
                          output_dir, rel_img_path, m, no_img_indices,
                          name_counter, name_lock))  # <-- new in patch 3
            for _ in range(max_workers)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        all_valid_names = set(result_list)

    logger.info("All processing finished")
    return all_valid_names

# ---------- Script ----------
if __name__ == "__main__":
    # f_json = "/vlm/data/llava_next_500/sampled_data.json"
    f_json = "/data_1/llava_next_raw_full/megatron_format_780k.json"
    rel_img = "images"
    res = split_json_file(
        f_json,
        rel_img,
        chunk_dim=2000,
        m=8
    )
    print(f"Generated {len(res)} files in total")
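
For intuition, a small illustrative check of _unique_filename's collision handling (not part of the commit; a plain dict and threading.Lock stand in for the Manager proxies, which expose the same interface here):

import threading

counter, lock = {}, threading.Lock()
print(_unique_filename("cat.png", counter, lock))  # -> cat.png
print(_unique_filename("cat.png", counter, lock))  # -> cat_1.png  (first duplicate)
print(_unique_filename("cat.png", counter, lock))  # -> cat_2.png  (second duplicate)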
