|
| 1 | +import json |
| 2 | +import os |
| 3 | +import shutil |
| 4 | +import logging |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | +from multiprocessing import Manager, cpu_count, Process |
| 7 | +from tqdm import tqdm |
| 8 | + |
# 1) __img--output files are numbered independently
| 10 | + |
# Root-logger configuration: INFO level with timestamped, leveled messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-wide logger; all helpers below log through it.
# NOTE(review): this is the *root* logger — getLogger(__name__) would be the
# more conventional choice, but changing it would alter the logger identity.
logger = logging.getLogger()
| 16 | + |
# ---------- utilities ----------
def extract_filename_without_ext(image_path: str) -> str:
    """Return the base filename of *image_path* with its extension removed."""
    filename = os.path.basename(image_path)
    stem, _ext = os.path.splitext(filename)
    return stem
| 20 | + |
| 21 | + |
| 22 | + |
| 23 | +# ---------------------------- patch 1 ---------------------------- |
# Added: thread-safe counter for de-duplicating repeated filenames
| 25 | +from collections import defaultdict |
| 26 | +import re |
| 27 | +import threading |
| 28 | + |
| 29 | +def _unique_filename(name: str, name_counter, name_lock) -> str: |
| 30 | + base, ext = os.path.splitext(name) |
| 31 | + with name_lock: |
| 32 | + # 用 get 避免 KeyError |
| 33 | + cnt = name_counter.get(name, 0) |
| 34 | + name_counter[name] = cnt + 1 |
| 35 | + if cnt == 0: |
| 36 | + return name |
| 37 | + return f"{base}_{cnt}{ext}" |
| 38 | + |
| 39 | +# ----------------------------------------------------------------- |
| 40 | + |
| 41 | + |
| 42 | + |
# ---------- per-item processing ----------
def _process_single_item(args):
    """Thread-level worker: split one dataset element into its own JSON file.

    *args* is a 7-tuple (packed for ThreadPoolExecutor.map):
        item           -- dict for one sample; may hold "images" (str or list)
        base_dir       -- directory containing the input JSON file
        output_dir     -- destination for copied images and per-item JSON
        rel_img_path   -- optional image sub-directory relative to base_dir
        no_img_indices -- original indices of items without any images
        name_counter   -- shared dict consumed by _unique_filename
        name_lock      -- shared lock guarding name_counter

    Returns the JSON filename root (no extension), or None when the item
    referenced images but none of them exist on disk.
    """
    (item, base_dir, output_dir, rel_img_path, no_img_indices,
     name_counter, name_lock) = args

    # ---------- resolve absolute paths of the original images ----------
    original_image_paths = []
    if item.get("images"):
        # normalize "images" to a list (it may be a single string)
        original_image_paths = item["images"] if isinstance(item["images"], list) else [item["images"]]
    else:
        item["images"] = []

    image_root = os.path.join(base_dir, rel_img_path) if rel_img_path else base_dir
    original_image_paths = [
        os.path.normpath(os.path.join(image_root, p))
        for p in original_image_paths
    ]

    # ---------- copy images under collision-free names ----------
    new_image_basenames = []
    for src_path in original_image_paths:
        if not os.path.exists(src_path):
            logger.warning(f"图片不存在:{src_path}")
            continue
        old_name = os.path.basename(src_path)
        new_name = _unique_filename(old_name, name_counter, name_lock)  # may rename
        new_image_basenames.append(new_name)

        dst_path = os.path.join(output_dir, new_name)
        try:
            shutil.copy2(src_path, dst_path)
        except Exception as e:
            # best-effort: a failed copy is logged, not fatal
            logger.error(f"拷贝图片失败: {src_path} -> {dst_path} | {e}")

    # keep the JSON payload consistent with the renamed files
    item["images"] = new_image_basenames

    # Item referenced images but none exist on disk: skip it entirely.
    if original_image_paths and not new_image_basenames:
        logger.info(f"跳过无有效图片的元素:{item.get('id', item['_orig_index'])}")
        return None

    # ---------- derive the per-item JSON filename ----------
    # Bug fix: pop the internal bookkeeping index so it is NOT leaked into
    # the dumped JSON file (split_json_file injects it only for naming).
    orig_index = item.pop('_orig_index', None)
    if new_image_basenames:
        json_name_root = os.path.splitext(new_image_basenames[0])[0]
    else:
        # image-less items are numbered independently by their rank among
        # all image-less items
        idx_in_no_img = no_img_indices.index(orig_index)
        json_name_root = f"__img--output_{idx_in_no_img:08d}"

    json_name = _unique_filename(json_name_root + ".json", name_counter, name_lock)
    json_path = os.path.join(output_dir, json_name)
    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(item, f, indent=4, ensure_ascii=False)
    except Exception as e:
        logger.error(f"写 JSON 失败: {json_path} | {e}")

    return os.path.splitext(json_name)[0]
| 116 | + |
# ---------- process-level worker ----------
def _worker_process(job_queue, result_list, base_dir, output_dir,
                    rel_img_path, m, no_img_indices,
                    name_counter, name_lock):
    """Process-level worker: drain chunks from *job_queue* until it is empty.

    Each chunk is fanned out to a thread pool of *m* workers via
    _process_single_item; the filename roots of successfully processed items
    are appended to the shared *result_list*.
    """
    from queue import Empty  # local import: only this worker needs it

    while True:
        try:
            chunk = job_queue.get_nowait()
        # Bug fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        # and masked real manager errors; get_nowait() signals exhaustion with
        # queue.Empty specifically.
        except Empty:
            break

        logger.info(f"进程 {os.getpid()} 处理 chunk({len(chunk)} 条)")
        # pack per-item argument tuples for ThreadPoolExecutor.map
        arg_list = [(item, base_dir, output_dir, rel_img_path, no_img_indices,
                     name_counter, name_lock)
                    for item in chunk]

        valid_names = []
        with ThreadPoolExecutor(max_workers=m) as pool:
            for result in tqdm(pool.map(_process_single_item, arg_list),
                               total=len(arg_list),
                               desc=f"PID-{os.getpid()}",
                               leave=False):
                if result is not None:  # drop items skipped for missing images
                    valid_names.append(result)
        result_list.extend(valid_names)
| 141 | + |
# ---------- main entry ----------
def split_json_file(fin_name, rel_img_path=None, *, chunk_dim=1000, m=8):
    """Split a JSON array file into one JSON file per element, copying images.

    Args:
        fin_name:     path to the input JSON file (root must be a list).
        rel_img_path: optional image sub-directory relative to the input file.
        chunk_dim:    number of items handed to each worker process at a time.
        m:            thread count inside each worker process.

    Returns:
        Set of generated filename roots; empty set when loading fails.
    """
    # ---- load the input array ----
    try:
        with open(fin_name, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        logger.error(f"读取 JSON 失败: {e}")
        return set()

    if not isinstance(data, list):
        logger.error("JSON 根节点不是数组")
        return set()

    # ---- tag original indices and record which items lack images ----
    for idx, record in enumerate(data):
        record['_orig_index'] = idx
    no_img_indices = [idx for idx, record in enumerate(data)
                      if not record.get("images")]

    # ---- prepare a fresh output directory ----
    base_dir = os.path.dirname(os.path.abspath(fin_name))
    output_dir = os.path.join(base_dir, "split_json_files")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # ---- slice the data into chunks ----
    total = len(data)
    chunks = [data[start:start + chunk_dim]
              for start in range(0, total, chunk_dim)]
    num_chunks = len(chunks)

    max_workers = min(num_chunks, cpu_count())
    logger.info(f"共 {total} 条,切成 {num_chunks} 块,启动 {max_workers} 进程,每进程 {m} 线程")

    with Manager() as manager:
        job_queue = manager.Queue()
        for chunk in chunks:
            job_queue.put(chunk)

        result_list = manager.list()
        name_counter = manager.dict()  # shared duplicate-filename counter
        name_lock = manager.Lock()     # guards name_counter across processes

        workers = []
        for _ in range(max_workers):
            workers.append(
                Process(target=_worker_process,
                        args=(job_queue, result_list, base_dir,
                              output_dir, rel_img_path, m, no_img_indices,
                              name_counter, name_lock))
            )
        for proc in workers:
            proc.start()
        for proc in workers:
            proc.join()

        all_valid_names = set(result_list)

    logger.info("全部处理完成")
    return all_valid_names
| 201 | + |
# ---------- script entry ----------
if __name__ == "__main__":
    # f_json = "/vlm/data/llava_next_500/sampled_data.json"
    f_json = "/data_1/llava_next_raw_full/megatron_format_780k.json"
    rel_img = "images"
    res = split_json_file(
        f_json,
        rel_img,  # bug fix: rel_img was assigned but a duplicate literal was passed
        chunk_dim=2000,
        m=8
    )
    print(f"共生成 {len(res)} 个文件")
0 commit comments