|
13 | 13 | from tqdm import tqdm |
14 | 14 |
|
15 | 15 | ValueType = Union[str, Path, None] |
| 16 | +RECTANGLE = "rectangle" |
| 17 | +POLYGON = "polygon" |
16 | 18 |
|
17 | 19 |
|
18 | 20 | class LabelmeToCOCO: |
@@ -192,24 +194,36 @@ def generate_json(self, img_list, save_dir): |
192 | 194 | shapes = raw_json_data.get("shapes", []) |
193 | 195 | anno_list = [] |
194 | 196 | for shape in shapes: |
| 197 | + shape_type = shape.get("shape_type") |
| 198 | + if shape_type not in [RECTANGLE, POLYGON]: |
| 199 | + print( |
| 200 | + f"Current shape type is {shape_type}, not between {RECTANGLE} and {POLYGON}, skip" |
| 201 | + ) |
| 202 | + continue |
| 203 | + |
195 | 204 | label_name = shape.get("label") |
196 | 205 | label_id = self.cls_to_idx[label_name] |
197 | | - |
198 | 206 | points = np.array(shape.get("points")) |
199 | | - x0, y0 = np.min(points, axis=0) |
200 | | - x1, y1 = np.max(points, axis=0) |
201 | | - area = (x1 - x0) * (y1 - y0) |
202 | | - |
203 | | - seg_points = [np.ravel(points, order="C").tolist()] |
204 | | - one_anno_dict = { |
205 | | - "segmentation": seg_points, |
206 | | - "area": area, |
207 | | - "iscrowd": 0, |
208 | | - "image_id": img_id, |
209 | | - "bbox": [x0, y0, x1, y1], |
210 | | - "category_id": label_id, |
211 | | - "id": self.object_id, |
212 | | - } |
| 207 | + |
| 208 | + if shape_type == RECTANGLE: |
| 209 | + x0, y0 = np.min(points, axis=0) |
| 210 | + x1, y1 = np.max(points, axis=0) |
| 211 | + area = (x1 - x0) * (y1 - y0) |
| 212 | + |
| 213 | + seg_points = [np.ravel(points, order="C").tolist()] |
| 214 | + |
| 215 | + one_anno_dict = { |
| 216 | + "segmentation": seg_points, |
| 217 | + "area": area, |
| 218 | + "iscrowd": 0, |
| 219 | + "image_id": img_id, |
| 220 | + "bbox": [x0, y0, x1, y1], |
| 221 | + "category_id": label_id, |
| 222 | + "id": self.object_id, |
| 223 | + } |
| 224 | + elif shape_type == POLYGON: |
| 225 | + pass |
| 226 | + |
213 | 227 | anno_list.append(one_anno_dict) |
214 | 228 | self.object_id += 1 |
215 | 229 | anno["annotations"].extend(anno_list) |
@@ -244,7 +258,11 @@ def cp_file(self, file_path: Path, dst_dir: Path): |
244 | 258 |
|
245 | 259 | def main(): |
246 | 260 | parser = argparse.ArgumentParser("Datasets converter from labelme to COCO") |
247 | | - parser.add_argument("--data_dir", type=str, default=None) |
| 261 | + parser.add_argument( |
| 262 | + "--data_dir", |
| 263 | + type=str, |
| 264 | + default="/Users/joshuawang/projects/_self/LabelConvert/data", |
| 265 | + ) |
248 | 266 | parser.add_argument("--save_dir", type=str, default=None) |
249 | 267 | parser.add_argument("--val_ratio", type=float, default=0.2) |
250 | 268 | parser.add_argument("--have_test", action="store_true", default=False) |
|
0 commit comments