88import json
99import shutil
1010
11- import cv2 as cv
11+ import cv2
1212
1313
1414def read_txt (txt_path ):
@@ -47,6 +47,7 @@ def __init__(self, dir_path):
4747 # 构建json内容结构
4848 self .type = 'instances'
4949 self .categories = []
50+ self .annotation_id = 1
5051
5152 # 读取类别数
5253 self ._get_category ()
@@ -68,9 +69,9 @@ def _get_category(self):
6869 class_list = read_txt (self .src_data / 'classes.txt' )
6970 for i , category in enumerate (class_list , 1 ):
7071 self .categories .append ({
72+ 'supercategory' : category ,
7173 'id' : i ,
7274 'name' : category ,
73- 'supercategory' : category ,
7475 })
7576
7677 def generate (self ):
@@ -96,7 +97,6 @@ def gen_dataset(self, img_paths, target_img_path, target_json):
9697 """
9798 images = []
9899 annotations = []
99- annotation_id = 1
100100 for img_id , img_path in enumerate (img_paths , 1 ):
101101 img_path = Path (img_path )
102102
@@ -106,7 +106,7 @@ def gen_dataset(self, img_paths, target_img_path, target_json):
106106 label_path = str (img_path .parent .parent
107107 / 'labels' / f'{ img_path .stem } .txt' )
108108
109- imgsrc = cv .imread (str (img_path ))
109+ imgsrc = cv2 .imread (str (img_path ))
110110 height , width = imgsrc .shape [:2 ]
111111
112112 dest_file_name = f'{ img_id :012d} .jpg'
@@ -115,7 +115,7 @@ def gen_dataset(self, img_paths, target_img_path, target_json):
115115 if img_path .suffix .lower () == ".jpg" :
116116 shutil .copyfile (img_path , save_img_path )
117117 else :
118- cv .imwrite (str (save_img_path ), imgsrc )
118+ cv2 .imwrite (str (save_img_path ), imgsrc )
119119
120120 images .append ({
121121 'date_captured' : '2021' ,
@@ -127,8 +127,7 @@ def gen_dataset(self, img_paths, target_img_path, target_json):
127127
128128 if Path (label_path ).exists ():
129129 new_anno = self .read_annotation (label_path , img_id ,
130- height , width ,
131- annotation_id )
130+ height , width )
132131 if len (new_anno ) > 0 :
133132 annotations .extend (new_anno )
134133 else :
@@ -148,10 +147,11 @@ def gen_dataset(self, img_paths, target_img_path, target_json):
148147 json .dump (json_data , f , ensure_ascii = False )
149148
150149 def read_annotation (self , txtfile , img_id ,
151- height , width , annotation_id ):
150+ height , width ):
152151 annotation = []
153152 allinfo = read_txt (txtfile )
154153 for label_info in allinfo :
154+ # 遍历一张图中不同标注对象
155155 label_info = label_info .split (" " )
156156 if len (label_info ) < 5 :
157157 continue
@@ -166,9 +166,9 @@ def read_annotation(self, txtfile, img_id,
166166 'image_id' : img_id ,
167167 'bbox' : bbox ,
168168 'category_id' : int (category_id )+ 1 ,
169- 'id' : annotation_id ,
169+ 'id' : self . annotation_id ,
170170 })
171- annotation_id += 1
171+ self . annotation_id += 1
172172 return annotation
173173
174174 @staticmethod
0 commit comments