|
| 1 | +# -*- encoding: utf-8 -*- |
| 2 | +# @Author: SWHL |
| 3 | + |
| 4 | +import argparse |
| 5 | +import json |
| 6 | +import shutil |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | +import numpy as np |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | + |
class COCO2labelImg():
    """Convert a COCO-layout detection dataset into labelImg/YOLO txt labels.

    Expected input layout under ``data_dir``::

        annotations/instances_train2017.json
        annotations/instances_val2017.json
        train2017/
        val2017/

    Output is written to ``<parent of data_dir>/COCO_labelImg_format`` with a
    ``train`` and a ``val`` sub-directory. Each of them receives a
    ``classes.txt``, one ``<image stem>.txt`` per labelled image and a copy of
    that image.
    """

    def __init__(self, data_dir: str = None):
        """Validate the input layout and create the output directories.

        Args:
            data_dir: Root directory of the COCO-format dataset.

        Raises:
            TypeError: If ``data_dir`` is None.
            FileNotFoundError: If a required file or directory is missing.
        """
        if data_dir is None:
            # Path(None) would raise an opaque TypeError; keep the type but
            # say what is actually wrong.
            raise TypeError('data_dir must be a path to the dataset root, '
                            'got None')

        # coco dir
        self.data_dir = Path(data_dir)
        self.verify_exists(self.data_dir)

        anno_dir = self.data_dir / 'annotations'
        self.verify_exists(anno_dir)

        self.train_json = anno_dir / 'instances_train2017.json'
        self.val_json = anno_dir / 'instances_val2017.json'
        self.verify_exists(self.train_json)
        self.verify_exists(self.val_json)

        self.train2017_dir = self.data_dir / 'train2017'
        self.val2017_dir = self.data_dir / 'val2017'
        self.verify_exists(self.train2017_dir)
        self.verify_exists(self.val2017_dir)

        # save dir
        self.save_dir = self.data_dir.parent / 'COCO_labelImg_format'
        self.mkdir(self.save_dir)

        self.save_train_dir = self.save_dir / 'train'
        self.mkdir(self.save_train_dir)

        self.save_val_dir = self.save_dir / 'val'
        self.mkdir(self.save_val_dir)

    def __call__(self, ):
        """Convert the train split, then the val split."""
        train_list = [self.train_json, self.save_train_dir, self.train2017_dir]
        self.convert(train_list)

        val_list = [self.val_json, self.save_val_dir, self.val2017_dir]
        self.convert(val_list)

        print(f'Successfully convert, detail in {self.save_dir}')

    def convert(self, info_list: list):
        """Convert one split.

        Args:
            info_list: ``[json_path, save_dir, img_dir]`` - the annotation
                file, the output directory and the source image directory.
        """
        json_path, save_dir, img_dir = info_list

        data = self.read_json(str(json_path))
        self.gen_classes_txt(save_dir, data.get('categories') or [])

        labels_by_image = self._collect_labels(data)
        for img_name, label_lines in tqdm(labels_by_image.items()):
            # Write the annotation txt file. All labels of an image are
            # written in one go with mode 'w', so running the converter twice
            # does not duplicate label lines.
            txt_full_path = save_dir / f'{Path(img_name).stem}.txt'
            self.write_txt(txt_full_path, label_lines, mode='w')

            # Copy the image into the converted directory (once per image,
            # not once per annotation).
            shutil.copy2(img_dir / img_name, save_dir)

    def _collect_labels(self, data: dict) -> dict:
        """Group label lines (``'<class> <xc> <yc> <w> <h>'``) by image name.

        Annotations without a segmentation, or whose ``image_id`` is not
        listed under ``images``, are skipped.

        Returns:
            Dict mapping image ``file_name`` to its list of label lines, in
            annotation order.
        """
        category_index = self._build_category_index(
            data.get('categories') or [])
        id_img_dict = {v['id']: v for v in data.get('images') or []}

        labels_by_image = {}
        for one_anno in data.get('annotations') or []:
            image_info = id_img_dict.get(one_anno['image_id'])
            if image_info is None:
                continue

            bbox = self._anno_to_xyxy(one_anno)
            if bbox is None:
                continue

            xywh = self.xyxy_to_xywh(bbox,
                                     image_info.get('width'),
                                     image_info.get('height'))

            # Class index must match the line order of classes.txt. Ids that
            # are missing from `categories` fall back to the legacy `id - 1`.
            raw_id = int(one_anno.get('category_id'))
            category_id = category_index.get(raw_id, raw_id - 1)

            xywh_str = ' '.join(str(v) for v in xywh)
            img_name = image_info.get('file_name')
            labels_by_image.setdefault(img_name, []).append(
                f'{category_id} {xywh_str}')
        return labels_by_image

    def _anno_to_xyxy(self, one_anno: dict):
        """Return ``[x0, y0, x1, y1]`` for one annotation, or None to skip it.

        Polygon segmentations (lists) go through ``get_bbox``. Any other
        non-empty segmentation (e.g. an RLE dict) cannot be decoded here, so
        the annotation's ``bbox`` field (``[x, y, width, height]``) is used
        when present.
        """
        seg_info = one_anno.get('segmentation')
        if not seg_info:
            return None
        if isinstance(seg_info, list):
            return self.get_bbox(seg_info)

        box = one_anno.get('bbox')
        if not box:
            return None
        x, y, w, h = box
        return [x, y, x + w, y + h]

    @staticmethod
    def _build_category_index(categories) -> dict:
        """Map each category id to its position in ``categories``.

        ``gen_classes_txt`` writes the names in this same order, so the
        position is the class index the label files must use even when the
        ids are not contiguous or do not start at 1.
        """
        return {int(cat['id']): idx for idx, cat in enumerate(categories)}

    @staticmethod
    def read_json(json_path):
        """Load and return the parsed content of a UTF-8 JSON file."""
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data

    def gen_classes_txt(self, save_dir, categories_dict):
        """Write ``classes.txt`` (one category name per line) to save_dir."""
        class_info = [value['name'] for value in categories_dict]
        self.write_txt(save_dir / 'classes.txt', class_info)

    def get_bbox(self, seg_info):
        """Return the axis-aligned box enclosing a polygon segmentation.

        Args:
            seg_info: List of polygons, each a flat ``[x, y, x, y, ...]``
                list with any number of vertices.

        Returns:
            ``[x0, y0, x1, y1]`` covering every vertex of every polygon.
        """
        # reshape(-1, 2) accepts any vertex count (the former reshape(4, 2)
        # only worked for quadrilaterals); concatenating makes the box cover
        # objects that are split into several polygons.
        points = np.concatenate(
            [np.asarray(polygon).reshape(-1, 2) for polygon in seg_info],
            axis=0)
        x0, y0 = np.min(points, axis=0)
        x1, y1 = np.max(points, axis=0)
        bbox = [x0, y0, x1, y1]
        return bbox

    @staticmethod
    def write_txt(save_path: str, content: list, mode='w'):
        """Write ``content`` to ``save_path``, one item per line.

        Args:
            save_path: Target file (str or path-like).
            content: A single string or an iterable of values.
            mode: File mode passed to ``open`` ('w' overwrite, 'a' append).
        """
        if not isinstance(save_path, str):
            save_path = str(save_path)

        if isinstance(content, str):
            content = [content]
        with open(save_path, mode, encoding='utf-8') as f:
            f.writelines(f'{value}\n' for value in content)

    @staticmethod
    def xyxy_to_xywh(xyxy: list,
                     img_width: int,
                     img_height: int) -> tuple:
        """Convert a corner box to a normalised centre/size box.

        Args:
            xyxy: ``[x1, y1, x2, y2]`` in pixels.
            img_width: Image width in pixels.
            img_height: Image height in pixels.

        Returns:
            ``(x_center, y_center, w, h)``, each divided by the matching
            image dimension.
        """
        x_center = (xyxy[0] + xyxy[2]) / (2 * img_width)
        y_center = (xyxy[1] + xyxy[3]) / (2 * img_height)

        box_w = abs(xyxy[2] - xyxy[0])
        box_h = abs(xyxy[3] - xyxy[1])

        w = box_w / img_width
        h = box_h / img_height
        return x_center, y_center, w, h

    @staticmethod
    def verify_exists(file_path):
        """Raise FileNotFoundError when ``file_path`` does not exist."""
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f'The {file_path} is not exists!!!')

    @staticmethod
    def mkdir(dir_path):
        """Create ``dir_path`` (and parents); no error if it exists."""
        Path(dir_path).mkdir(parents=True, exist_ok=True)
| 135 | + |
| 136 | + |
if __name__ == '__main__':
    # Command-line entry point: read the dataset root from --data_dir and
    # run the conversion for both splits.
    arg_parser = argparse.ArgumentParser('Datasets convert from COCO to labelImg')
    arg_parser.add_argument(
        '--data_dir',
        default='dataset/YOLOV5_COCO_format',
        type=str,
        help='Dataset root path',
    )
    cli_args = arg_parser.parse_args()

    # Instantiating validates the layout; calling the instance converts.
    COCO2labelImg(cli_args.data_dir)()
0 commit comments