 class SingleShotMultiBoxDetector:
     """
     """
+
     available_type = ["ssd300", "ssd512"]
     available_net = ["vgg16", "resnet50", "xception"]

     ar_presets = dict(
-        ssd300=[[2., 1 / 2.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2.],
-                [2., 1 / 2.]],
-        ssd512=[[2., 1 / 2.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2., 3., 1 / 3.],
-                [2., 1 / 2.],
-                [2., 1 / 2.]]
+        ssd300=[
+            [2.0, 1 / 2.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0],
+            [2.0, 1 / 2.0],
+        ],
+        ssd512=[
+            [2.0, 1 / 2.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0, 3.0, 1 / 3.0],
+            [2.0, 1 / 2.0],
+            [2.0, 1 / 2.0],
+        ],
     )
     scale_presets = dict(
-        ssd300=[(30., 60.),
-                (60., 111.),
-                (111., 162.),
-                (162., 213.),
-                (213., 264.),
-                (264., 315.)],
-        ssd512=[(20.48, 51.2),
-                (51.2, 133.12),
-                (133.12, 215.04),
-                (215.04, 296.96),
-                (296.96, 378.88),
-                (378.88, 460.8),
-                (460.8, 542.72)]
-    )
-    default_shapes = dict(
-        ssd300=(300, 300, 3)
+        ssd300=[
+            (30.0, 60.0),
+            (60.0, 111.0),
+            (111.0, 162.0),
+            (162.0, 213.0),
+            (213.0, 264.0),
+            (264.0, 315.0),
+        ],
+        ssd512=[
+            (20.48, 51.2),
+            (51.2, 133.12),
+            (133.12, 215.04),
+            (215.04, 296.96),
+            (296.96, 378.88),
+            (378.88, 460.8),
+            (460.8, 542.72),
+        ],
     )
+    default_shapes = dict(ssd300=(300, 300, 3))

-    def __init__(self, n_classes=1, class_names=["bg"], input_shape=None,
-                 aspect_ratios=None, scales=None, variances=None,
-                 overlap_threshold=0.5, nms_threshold=0.45,
-                 max_output_size=400,
-                 model_type="ssd300", base_net="vgg16"):
+    def __init__(
+        self,
+        n_classes=1,
+        class_names=["bg"],
+        input_shape=None,
+        aspect_ratios=None,
+        scales=None,
+        variances=None,
+        overlap_threshold=0.5,
+        nms_threshold=0.45,
+        max_output_size=400,
+        model_type="ssd300",
+        base_net="vgg16",
+    ):
         """
         """
         self.n_classes = n_classes
         self.class_names = class_names
         if "bg" != class_names[0]:
-            print("Warning: First label should be bg."
-                  " It'll be added automatically.")
+            print("Warning: First label should be bg." " It'll be added automatically.")
             self.class_names = ["bg"] + class_names
             self.n_classes += 1
         if input_shape:
@@ -101,86 +116,93 @@ def build(self, init_weight="keras_imagenet"):
         """
         # create network
         if self.model_type == "ssd300" and self.base_net == "vgg16":
-            self.model, priors = SSD300_vgg16(self.input_shape,
-                                              self.n_classes,
-                                              self.aspect_ratios,
-                                              self.scales)
+            self.model, priors = SSD300_vgg16(
+                self.input_shape, self.n_classes, self.aspect_ratios, self.scales
+            )
         elif self.model_type == "ssd300" and self.base_net == "resnet50":
-            self.model, priors = SSD300_resnet50(self.input_shape,
-                                                 self.n_classes,
-                                                 self.aspect_ratios,
-                                                 self.scales)
+            self.model, priors = SSD300_resnet50(
+                self.input_shape, self.n_classes, self.aspect_ratios, self.scales
+            )
         elif self.model_type == "ssd300" and self.base_net == "xception":
-            self.model, priors = SSD300_xception(self.input_shape,
-                                                 self.n_classes,
-                                                 self.aspect_ratios,
-                                                 self.scales)
+            self.model, priors = SSD300_xception(
+                self.input_shape, self.n_classes, self.aspect_ratios, self.scales
+            )
         elif self.model_type == "ssd512" and self.base_net == "vgg16":
-            self.model, priors = SSD512_vgg16(self.input_shape,
-                                              self.n_classes,
-                                              self.aspect_ratios,
-                                              self.scales)
+            self.model, priors = SSD512_vgg16(
+                self.input_shape, self.n_classes, self.aspect_ratios, self.scales
+            )
         elif self.model_type == "ssd512" and self.base_net == "resnet50":
-            self.model, priors = SSD512_resnet50(self.input_shape,
-                                                 self.n_classes,
-                                                 self.aspect_ratios,
-                                                 self.scales)
+            self.model, priors = SSD512_resnet50(
+                self.input_shape, self.n_classes, self.aspect_ratios, self.scales
+            )
         else:
             raise NameError(
                 "{},{} is not defined. types are {}, basenets are {}.".format(
-                    self.model_type, self.base_net,
-                    self.available_type, self.available_net
+                    self.model_type,
+                    self.base_net,
+                    self.available_type,
+                    self.available_net,
                 )
             )

         if init_weight is None:
             print("Network has not been initialized with any pretrained model.")
         elif init_weight == "keras_imagenet":
-            print("Initializing network with keras application model"
-                  " pretrained on ImageNet.")
+            print(
+                "Initializing network with keras application model"
+                " pretrained on ImageNet."
+            )
             if self.base_net == "vgg16":
                 import keras.applications.vgg16 as keras_vgg16
+
                 weights_path = keras_vgg16.get_file(
-                    'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
+                    "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5",
                     keras_vgg16.WEIGHTS_PATH_NO_TOP,
-                    cache_subdir="models"
+                    cache_subdir="models",
                 )
             elif self.base_net == "resnet50":
                 import keras.applications.resnet50 as keras_resnet50
+
                 weights_path = keras_resnet50.get_file(
-                    'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
+                    "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5",
                     keras_resnet50.WEIGHTS_PATH_NO_TOP,
-                    cache_subdir="models"
+                    cache_subdir="models",
                 )
             elif self.base_net == "xception":
                 import keras.applications.xception as keras_xception
+
                 weights_path = keras_xception.get_file(
-                    'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
+                    "xception_weights_tf_dim_ordering_tf_kernels_notop.h5",
                     keras_xception.TF_WEIGHTS_PATH_NO_TOP,
-                    cache_subdir="models"
+                    cache_subdir="models",
                 )
             else:
-                raise NameError(
-                    "{} is not defined.".format(
-                        self.base_net
-                    )
-                )
+                raise NameError("{} is not defined.".format(self.base_net))
             self.model.load_weights(weights_path, by_name=True)
         else:
             print("Initializing network from file {}.".format(init_weight))
             self.model.load_weights(init_weight, by_name=True)

         # make boundary box class
-        self.bboxes = BoundaryBox(n_classes=self.n_classes,
-                                  default_boxes=priors,
-                                  variances=self.variances,
-                                  overlap_threshold=self.overlap_threshold,
-                                  nms_threshold=self.nms_threshold,
-                                  max_output_size=self.max_output_size)
+        self.bboxes = BoundaryBox(
+            n_classes=self.n_classes,
+            default_boxes=priors,
+            variances=self.variances,
+            overlap_threshold=self.overlap_threshold,
+            nms_threshold=self.nms_threshold,
+            max_output_size=self.max_output_size,
+        )

-    def train_by_generator(self, gen, epoch=30, neg_pos_ratio=3.0,
-                           learning_rate=1e-3, freeze=None, checkpoints=None,
-                           optimizer=None):
+    def train_by_generator(
+        self,
+        gen,
+        epoch=30,
+        neg_pos_ratio=3.0,
+        learning_rate=1e-3,
+        freeze=None,
+        checkpoints=None,
+        optimizer=None,
+    ):
         """
         """
         # set freeze layers
@@ -196,14 +218,13 @@ def train_by_generator(self, gen, epoch=30, neg_pos_ratio=3.0,
         if checkpoints:
             callbacks.append(
                 keras.callbacks.ModelCheckpoint(
-                    checkpoints,
-                    verbose=1,
-                    save_weights_only=True
-                ),
+                    checkpoints, verbose=1, save_weights_only=True
+                )
             )

         def schedule(epoch, decay=0.9):
-            return learning_rate * decay**(epoch)
+            return learning_rate * decay ** (epoch)
+
         callbacks.append(keras.callbacks.LearningRateScheduler(schedule))

         if optimizer is None:
@@ -216,20 +237,17 @@ def schedule(epoch, decay=0.9):

         self.model.compile(
             optimizer=optim,
-            loss=MultiBoxLoss(
-                self.n_classes,
-                neg_pos_ratio=neg_pos_ratio
-            ).compute_loss
+            loss=MultiBoxLoss(self.n_classes, neg_pos_ratio=neg_pos_ratio).compute_loss,
         )
         history = self.model.fit_generator(
             gen.generate(self.preprocesser, True),
-            int(gen.train_batches/gen.batch_size),
+            int(gen.train_batches / gen.batch_size),
             epochs=epoch,
             verbose=1,
             callbacks=callbacks,
             validation_data=gen.generate(self.preprocesser, False),
-            validation_steps=int(gen.val_batches/gen.batch_size),
-            workers=1
+            validation_steps=int(gen.val_batches / gen.batch_size),
+            workers=1,
         )

         return history
@@ -245,7 +263,7 @@ def save_parameters(self, filepath="./param.json"):
             base_net=self.base_net,
             aspect_ratios=self.aspect_ratios,
             scales=self.scales,
-            variances=self.variances
+            variances=self.variances,
         )
         print("Writing parameters into {}.".format(filepath))
         json.dump(params, open(filepath, "w"), indent=4, sort_keys=True)
@@ -264,23 +282,27 @@ def load_parameters(self, filepath):
         self.scales = params["scales"]
         self.variances = params["variances"]

-    def detect(self, X, batch_size=1, verbose=0,
-               keep_top_k=200, confidence_threshold=0.01,
-               do_preprocess=True):
+    def detect(
+        self,
+        X,
+        batch_size=1,
+        verbose=0,
+        keep_top_k=200,
+        confidence_threshold=0.01,
+        do_preprocess=True,
+    ):
         """
         """
         if do_preprocess:
             inputs = self.preprocesser(X.copy())
         else:
             inputs = X.copy()

-        predictions = self.model.predict(inputs,
-                                         batch_size=batch_size,
-                                         verbose=verbose)
+        predictions = self.model.predict(inputs, batch_size=batch_size, verbose=verbose)
         detections = self.bboxes.detection_out(
             predictions,
             keep_top_k=keep_top_k,
-            confidence_threshold=confidence_threshold
+            confidence_threshold=confidence_threshold,
         )

         return detections
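
For reference, a minimal usage sketch of the class as it reads after this reformatting. The constructor, build(), and detect() signatures come from the code above; the module path (ssd_detector), the class labels, and the dummy image array are assumptions for illustration only.

import numpy as np

from ssd_detector import SingleShotMultiBoxDetector  # hypothetical import path

# First label must be "bg"; otherwise the constructor prepends it and bumps n_classes.
detector = SingleShotMultiBoxDetector(
    n_classes=3,
    class_names=["bg", "person", "car"],
    model_type="ssd300",
    base_net="vgg16",
)
detector.build(init_weight="keras_imagenet")  # load ImageNet-pretrained base net weights

# One dummy 300x300 RGB image; detect() runs self.preprocesser on it by default.
images = np.random.randint(0, 255, size=(1, 300, 300, 3)).astype("float32")
detections = detector.detect(images, keep_top_k=200, confidence_threshold=0.5)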