1- #-------------------------------------------------------------------------------------------------#
1+ #------------------------------------------------------------------------------------------------------- #
22# kmeans虽然会对数据集中的框进行聚类,但是很多数据集由于框的大小相近,聚类出来的9个框相差不大,
3- # 这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,越浅的特征层适合越大的先验框
3+ # 这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,shape越小的特征层适合越大的先验框
44# 原始网络的先验框已经按大中小比例分配好了,不进行聚类也会有非常好的效果。
5- #-------------------------------------------------------------------------------------------------#
5+ #------------------------------------------------------------------------------------------------------- #
66import glob
77import xml .etree .ElementTree as ET
88
9+ import matplotlib .pyplot as plt
910import numpy as np
11+ from tqdm import tqdm
1012
11- def cas_iou (box ,cluster ):
12- x = np .minimum (cluster [:,0 ],box [0 ])
13- y = np .minimum (cluster [:,1 ],box [1 ])
13+
14+ def cas_iou (box , cluster ):
15+ x = np .minimum (cluster [:, 0 ], box [0 ])
16+ y = np .minimum (cluster [:, 1 ], box [1 ])
1417
1518 intersection = x * y
1619 area1 = box [0 ] * box [1 ]
1720
1821 area2 = cluster [:,0 ] * cluster [:,1 ]
19- iou = intersection / (area1 + area2 - intersection )
22+ iou = intersection / (area1 + area2 - intersection )
2023
2124 return iou
2225
23- def avg_iou (box ,cluster ):
24- return np .mean ([np .max (cas_iou (box [i ],cluster )) for i in range (box .shape [0 ])])
26+ def avg_iou (box , cluster ):
27+ return np .mean ([np .max (cas_iou (box [i ], cluster )) for i in range (box .shape [0 ])])
2528
26- def kmeans (box ,k ):
29+ def kmeans (box , k ):
2730 #-------------------------------------------------------------#
2831 # 取出一共有多少框
2932 #-------------------------------------------------------------#
@@ -32,30 +35,32 @@ def kmeans(box,k):
3235 #-------------------------------------------------------------#
3336 # 每个框各个点的位置
3437 #-------------------------------------------------------------#
35- distance = np .empty ((row ,k ))
38+ distance = np .empty ((row , k ))
3639
3740 #-------------------------------------------------------------#
3841 # 最后的聚类位置
3942 #-------------------------------------------------------------#
40- last_clu = np .zeros ((row ,))
43+ last_clu = np .zeros ((row , ))
4144
4245 np .random .seed ()
4346
4447 #-------------------------------------------------------------#
4548 # 随机选5个当聚类中心
4649 #-------------------------------------------------------------#
47- cluster = box [np .random .choice (row ,k ,replace = False )]
50+ cluster = box [np .random .choice (row , k , replace = False )]
51+
52+ iter = 0
4853 while True :
4954 #-------------------------------------------------------------#
50- # 计算每一行距离五个点的iou情况。
55+ # 计算当前框和先验框的宽高比例
5156 #-------------------------------------------------------------#
5257 for i in range (row ):
53- distance [i ] = 1 - cas_iou (box [i ],cluster )
58+ distance [i ] = 1 - cas_iou (box [i ], cluster )
5459
5560 #-------------------------------------------------------------#
5661 # 取出最小点
5762 #-------------------------------------------------------------#
58- near = np .argmin (distance ,axis = 1 )
63+ near = np .argmin (distance , axis = 1 )
5964
6065 if (last_clu == near ).all ():
6166 break
@@ -68,18 +73,21 @@ def kmeans(box,k):
6873 box [near == j ],axis = 0 )
6974
7075 last_clu = near
76+ if iter % 5 == 0 :
77+ print ('iter: {:d}. avg_iou:{:.2f}' .format (iter , avg_iou (box , cluster )))
78+ iter += 1
7179
72- return cluster
80+ return cluster , near
7381
7482def load_data (path ):
7583 data = []
7684 #-------------------------------------------------------------#
7785 # 对于每一个xml都寻找box
7886 #-------------------------------------------------------------#
79- for xml_file in glob .glob ('{}/*xml' .format (path )):
80- tree = ET .parse (xml_file )
81- height = int (tree .findtext ('./size/height' ))
82- width = int (tree .findtext ('./size/width' ))
87+ for xml_file in tqdm ( glob .glob ('{}/*xml' .format (path ) )):
88+ tree = ET .parse (xml_file )
89+ height = int (tree .findtext ('./size/height' ))
90+ width = int (tree .findtext ('./size/width' ))
8391 if height <= 0 or width <= 0 :
8492 continue
8593
@@ -97,42 +105,59 @@ def load_data(path):
97105 xmax = np .float64 (xmax )
98106 ymax = np .float64 (ymax )
99107 # 得到宽高
100- data .append ([xmax - xmin ,ymax - ymin ])
108+ data .append ([xmax - xmin , ymax - ymin ])
101109 return np .array (data )
102110
103-
104111if __name__ == '__main__' :
112+ np .random .seed (0 )
105113 #-------------------------------------------------------------#
106114 # 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml
107115 # 会生成yolo_anchors.txt
108116 #-------------------------------------------------------------#
109- SIZE = 416
117+ input_shape = [ 416 , 416 ]
110118 anchors_num = 9
111119 #-------------------------------------------------------------#
112120 # 载入数据集,可以使用VOC的xml
113121 #-------------------------------------------------------------#
114- path = r'./ VOCdevkit/VOC2007/Annotations'
122+ path = ' VOCdevkit/VOC2007/Annotations'
115123
116124 #-------------------------------------------------------------#
117125 # 载入所有的xml
118126 # 存储格式为转化为比例后的width,height
119127 #-------------------------------------------------------------#
128+ print ('Load xmls.' )
120129 data = load_data (path )
130+ print ('Load xmls done.' )
121131
122132 #-------------------------------------------------------------#
123133 # 使用k聚类算法
124134 #-------------------------------------------------------------#
125- out = kmeans (data ,anchors_num )
126- out = out [np .argsort (out [:,0 ])]
127- print ('acc:{:.2f}%' .format (avg_iou (data ,out ) * 100 ))
128- print (out * SIZE )
129- data = out * SIZE
135+ print ('K-means boxes.' )
136+ cluster , near = kmeans (data , anchors_num )
137+ print ('K-means boxes done.' )
138+ data = data * np .array ([input_shape [1 ], input_shape [0 ]])
139+ cluster = cluster * np .array ([input_shape [1 ], input_shape [0 ]])
140+
141+ #-------------------------------------------------------------#
142+ # 绘图
143+ #-------------------------------------------------------------#
144+ for j in range (anchors_num ):
145+ plt .scatter (data [near == j ][:,0 ], data [near == j ][:,1 ])
146+ plt .scatter (cluster [j ][0 ], cluster [j ][1 ], marker = 'x' , c = 'black' )
147+ plt .show ()
148+ plt .savefig ("kmeans_for_anchors.jpg" )
149+ print ('Save kmeans_for_anchors.jpg in root dir.' )
150+
151+ cluster = cluster [np .argsort (cluster [:, 0 ] * cluster [:, 1 ])]
152+ print ('avg_ratio:{:.2f}' .format (avg_iou (data , cluster )))
153+ print (data )
154+
130155 f = open ("yolo_anchors.txt" , 'w' )
131- row = np .shape (data )[0 ]
156+ row = np .shape (cluster )[0 ]
132157 for i in range (row ):
133158 if i == 0 :
134- x_y = "%d,%d" % (data [i ][0 ], data [i ][1 ])
159+ x_y = "%d,%d" % (cluster [i ][0 ], cluster [i ][1 ])
135160 else :
136- x_y = ", %d,%d" % (data [i ][0 ], data [i ][1 ])
161+ x_y = ", %d,%d" % (cluster [i ][0 ], cluster [i ][1 ])
137162 f .write (x_y )
138163 f .close ()
0 commit comments