@@ -181,18 +181,17 @@ def focal_loss(self, x, y):
181181 alpha = 0.25
182182 gamma = 2
183183
184- t = one_hot_embedding (y .data .cpu (), 1 + self .num_classes ) # [N,21]
185- t = t [:,1 :] # exclude background
184+ t = one_hot_embedding (y .data .cpu (), 1 + self .num_classes ) # [N,21]
185+ t = t [:, 1 :] # exclude background
186186 if torch .cuda .is_available ():
187187 t = Variable (t ).cuda () # [N,20]
188188 else :
189189 t = Variable (t ) # [N,20]
190190
191-
192191 p = x .sigmoid ()
193- pt = p * t + (1 - p ) * ( 1 - t ) # pt = p if t > 0 else 1-p
194- w = alpha * t + (1 - alpha )* ( 1 - t ) # w = alpha if t > 0 else 1-alpha
195- w = w * (1 - pt ).pow (gamma )
192+ pt = p * t + (1 - p ) * ( 1 - t ) # pt = p if t > 0 else 1-p
193+ w = alpha * t + (1 - alpha ) * ( 1 - t ) # w = alpha if t > 0 else 1-alpha
194+ w = w * (1 - pt ).pow (gamma )
196195 return F .binary_cross_entropy_with_logits (x , t , w , size_average = False )
197196
198197 def focal_loss_alt (self , x , y ):
@@ -207,18 +206,18 @@ def focal_loss_alt(self, x, y):
207206 """
208207 alpha = 0.25
209208
210- t = one_hot_embedding (y .data .cpu (), 1 + self .num_classes )
211- t = t [:,1 :]
209+ t = one_hot_embedding (y .data .cpu (), 1 + self .num_classes )
210+ t = t [:, 1 :]
212211 if torch .cuda .is_available ():
213212 t = Variable (t ).cuda () # [N,20]
214213 else :
215214 t = Variable (t ) # [N,20]
216215
217- xt = x * ( 2 * t - 1 ) # xt = x if t > 0 else -x
218- pt = (2 * xt + 1 ).sigmoid ()
216+ xt = x * ( 2 * t - 1 ) # xt = x if t > 0 else -x
217+ pt = (2 * xt + 1 ).sigmoid ()
219218
220- w = alpha * t + (1 - alpha )* ( 1 - t )
221- loss = - w * pt .log () / 2
219+ w = alpha * t + (1 - alpha ) * ( 1 - t )
220+ loss = - w * pt .log () / 2
222221 return loss .sum ()
223222
224223 def forward (self , output , target ):
@@ -239,15 +238,15 @@ def forward(self, output, target):
239238 pos = cls_targets > 0 # [N,#anchors]
240239 num_pos = pos .data .long ().sum ()
241240
242- mask = pos .unsqueeze (2 ).expand_as (loc_preds ) # [N,#anchors,4]
243- masked_loc_preds = loc_preds [mask ].view (- 1 ,4 ) # [#pos,4]
244- masked_loc_targets = loc_targets [mask ].view (- 1 ,4 ) # [#pos,4]
241+ mask = pos .unsqueeze (2 ).expand_as (loc_preds ) # [N,#anchors,4]
242+ masked_loc_preds = loc_preds [mask ].view (- 1 , 4 ) # [#pos,4]
243+ masked_loc_targets = loc_targets [mask ].view (- 1 , 4 ) # [#pos,4]
245244
246245 pos_neg = cls_targets > - 1 # exclude ignored anchors
247246 num_pos_neg = pos_neg .data .long ().sum ()
248247
249248 mask = pos_neg .unsqueeze (2 ).expand_as (cls_preds )
250- masked_cls_preds = cls_preds [mask ].view (- 1 ,self .num_classes )
249+ masked_cls_preds = cls_preds [mask ].view (- 1 , self .num_classes )
251250
252251 loc_loss = F .smooth_l1_loss (masked_loc_preds , masked_loc_targets , size_average = False )
253252 cls_loss = self .focal_loss (masked_cls_preds , cls_targets [pos_neg ])
@@ -264,7 +263,7 @@ def forward(self, output, target):
264263
265264class BaseDataHandler ():
266265 def __init__ (self , aspect_ratios , scale_ratios , num_anchors ):
267- self .anchor_areas = [32 * 32. , 64 * 64. , 128 * 128. , 256 * 256. , 512 * 512. ] # p3 -> p7
266+ self .anchor_areas = [32 * 32. , 64 * 64. , 128 * 128. , 256 * 256. , 512 * 512. ] # p3 -> p7
268267 self .aspect_ratios = aspect_ratios
269268 self .scale_ratios = scale_ratios
270269 self .num_anchors = num_anchors
@@ -279,11 +278,11 @@ def _get_anchor_hw(self):
279278 anchor_hw = []
280279 for s in self .anchor_areas :
281280 for ar in self .aspect_ratios : # w/h = ar
282- h = sqrt (s / ar )
281+ h = sqrt (s / ar )
283282 w = ar * h
284283 for sr in self .scale_ratios : # scale
285- anchor_h = h * sr
286- anchor_w = w * sr
284+ anchor_h = h * sr
285+ anchor_w = w * sr
287286 anchor_hw .append ([anchor_h , anchor_w ])
288287 num_fms = len (self .anchor_areas )
289288 return torch .Tensor (anchor_hw ).view (num_fms , - 1 , 2 )
@@ -299,18 +298,18 @@ def _get_anchor_boxes(self, input_size):
299298 where #anchors = fmw * fmh * #anchors_per_cell
300299 """
301300 num_fms = len (self .anchor_areas )
302- fm_sizes = [(input_size / pow (2. ,i + 3 )).ceil () for i in range (num_fms )] # p3 -> p7 feature map sizes
301+ fm_sizes = [(input_size / pow (2. , i + 3 )).ceil () for i in range (num_fms )] # p3 -> p7 feature map sizes
303302
304303 boxes = []
305304 for i in range (num_fms ):
306305 fm_size = fm_sizes [i ]
307306 grid_size = input_size / fm_size
308307 fm_h , fm_w = int (fm_size [0 ]), int (fm_size [1 ])
309- xy = meshgrid (fm_h ,fm_w ) + 0.5 # [fm_h*fm_w, 2]
310- xy = (xy * grid_size ).view (fm_w ,fm_h ,1 , 2 ).expand (fm_w ,fm_h ,self .num_anchors ,2 )
311- hw = self .anchor_wh [i ].view (1 ,1 , self .num_anchors ,2 ).expand (fm_w ,fm_h ,self .num_anchors ,2 )
312- box = torch .cat ([xy ,hw ], 3 ) # [x,y,w,h]
313- boxes .append (box .view (- 1 ,4 ))
308+ xy = meshgrid (fm_h , fm_w ) + 0.5 # [fm_h*fm_w, 2]
309+ xy = (xy * grid_size ).view (fm_w , fm_h , 1 , 2 ).expand (fm_w , fm_h , self .num_anchors , 2 )
310+ hw = self .anchor_wh [i ].view (1 , 1 , self .num_anchors , 2 ).expand (fm_w , fm_h , self .num_anchors , 2 )
311+ box = torch .cat ([xy , hw ], 3 ) # [x,y,w,h]
312+ boxes .append (box .view (- 1 , 4 ))
314313 return torch .cat (boxes , 0 )
315314
316315
@@ -335,7 +334,7 @@ def encode(self, boxes, labels, input_size):
335334 cls_targets: (tensor) encoded class labels, sized [#anchors,].
336335 """
337336 input_size = torch .Tensor ([input_size , input_size ]) if isinstance (input_size , int ) \
338- else torch .Tensor (input_size )
337+ else torch .Tensor (input_size )
339338 anchor_boxes = self ._get_anchor_boxes (input_size )
340339
341340 if len (boxes ) > 0 :
@@ -345,13 +344,13 @@ def encode(self, boxes, labels, input_size):
345344 max_ious , max_ids = ious .max (1 )
346345 boxes = boxes [max_ids ]
347346
348- loc_xy = (boxes [:,:2 ]- anchor_boxes [:,:2 ]) / anchor_boxes [:,2 :]
349- loc_hw = torch .log (boxes [:,2 :]/ anchor_boxes [:,2 :])
350- loc_targets = torch .cat ([loc_xy ,loc_hw ], 1 )
347+ loc_xy = (boxes [:, :2 ] - anchor_boxes [:, :2 ]) / anchor_boxes [:, 2 :]
348+ loc_hw = torch .log (boxes [:, 2 :] / anchor_boxes [:, 2 :])
349+ loc_targets = torch .cat ([loc_xy , loc_hw ], 1 )
351350 cls_targets = labels [max_ids ]
352351
353- cls_targets [max_ious < 0.5 ] = 0
354- ignore = (max_ious > 0.4 ) & (max_ious < 0.5 ) # ignore ious between [0.4,0.5]
352+ cls_targets [max_ious < 0.5 ] = 0
353+ ignore = (max_ious > 0.4 ) & (max_ious < 0.5 ) # ignore ious between [0.4,0.5]
355354 cls_targets [ignore ] = - 1 # for now just mark ignored to -1
356355 else :
357356 loc_targets = torch .zeros (len (anchor_boxes ), 4 )
@@ -386,25 +385,25 @@ def decode(self, loc_preds, cls_preds, input_size):
386385 CLS_THRESH = 0.5
387386 NMS_THRESH = 0.5
388387
389- input_size = torch .Tensor ([input_size ,input_size ]) if isinstance (input_size , int ) \
390- else torch .Tensor (input_size )
388+ input_size = torch .Tensor ([input_size , input_size ]) if isinstance (input_size , int ) \
389+ else torch .Tensor (input_size )
391390 anchor_boxes = self ._get_anchor_boxes (input_size )
392391
393- loc_xy = loc_preds [:,:2 ]
394- loc_hw = loc_preds [:,2 :]
392+ loc_xy = loc_preds [:, :2 ]
393+ loc_hw = loc_preds [:, 2 :]
395394
396- xy = loc_xy * anchor_boxes [:,2 :] + anchor_boxes [:,:2 ]
397- wh = loc_hw .exp () * anchor_boxes [:,2 :]
398- boxes = torch .cat ([xy - wh / 2 , xy + wh / 2 ], 1 ) # [#anchors,4]
395+ xy = loc_xy * anchor_boxes [:, 2 :] + anchor_boxes [:, :2 ]
396+ wh = loc_hw .exp () * anchor_boxes [:, 2 :]
397+ boxes = torch .cat ([xy - wh / 2 , xy + wh / 2 ], 1 ) # [#anchors,4]
399398
400- score , labels = cls_preds .sigmoid ().max (1 ) # [#anchors,]
399+ score , labels = cls_preds .sigmoid ().max (1 ) # [#anchors,]
401400 labels += 1
402401 ids = score > CLS_THRESH
403- ids = ids .nonzero ().squeeze () # [#obj,]
402+ ids = ids .nonzero ().squeeze () # [#obj,]
404403 if len (ids ) == 0 :
405- return torch .Tensor ([]), torch .Tensor ([])
404+ return torch .Tensor ([]), torch .Tensor ([]), torch . Tensor ([])
406405 keep = box_nms (boxes [ids ], score [ids ], threshold = NMS_THRESH )
407- return boxes [ids ][keep ], labels [ids ][keep ]
406+ return boxes [ids ][keep ], labels [ids ][keep ], score [ ids ][ keep ]
408407
409408
410409def one_hot_embedding (labels , num_classes ):
@@ -421,7 +420,7 @@ def one_hot_embedding(labels, num_classes):
421420 https://github.com/kuangliu/pytorch-retinanet
422421 """
423422 y = torch .eye (num_classes ) # [D,D]
424- return y [labels ] # [N,D]
423+ return y [labels ] # [N,D]
425424
426425
427426def meshgrid (x , y , row_major = True ):
@@ -457,11 +456,11 @@ def meshgrid(x, y, row_major=True):
457456 Reference:
458457 https://github.com/kuangliu/pytorch-retinanet
459458 """
460- a = torch .arange (0 ,x )
461- b = torch .arange (0 ,y )
462- xx = a .repeat (y ).view (- 1 ,1 )
463- yy = b .view (- 1 ,1 ).repeat (1 ,x ).view (- 1 ,1 )
464- return torch .cat ([xx ,yy ],1 ) if row_major else torch .cat ([yy ,xx ],1 )
459+ a = torch .arange (0 , x )
460+ b = torch .arange (0 , y )
461+ xx = a .repeat (y ).view (- 1 , 1 )
462+ yy = b .view (- 1 , 1 ).repeat (1 , x ).view (- 1 , 1 )
463+ return torch .cat ([xx , yy ], 1 ) if row_major else torch .cat ([yy , xx ], 1 )
465464
466465
467466def change_box_order (boxes , order ):
@@ -477,12 +476,12 @@ def change_box_order(boxes, order):
477476 Reference:
478477 https://github.com/kuangliu/pytorch-retinanet
479478 """
480- assert order in ['xyxy2xywh' ,'xywh2xyxy' ]
481- a = boxes [:,:2 ]
482- b = boxes [:,2 :]
479+ assert order in ['xyxy2xywh' , 'xywh2xyxy' ]
480+ a = boxes [:, :2 ]
481+ b = boxes [:, 2 :]
483482 if order == 'xyxy2xywh' :
484- return torch .cat ([(a + b ) / 2 , b - a + 1 ], 1 )
485- return torch .cat ([a - b / 2 , a + b / 2 ], 1 )
483+ return torch .cat ([(a + b ) / 2 , b - a + 1 ], 1 )
484+ return torch .cat ([a - b / 2 , a + b / 2 ], 1 )
486485
487486
488487def box_iou (box1 , box2 , order = 'xyxy' ):
@@ -509,15 +508,15 @@ def box_iou(box1, box2, order='xyxy'):
509508 N = box1 .size (0 )
510509 M = box2 .size (0 )
511510
512- lt = torch .max (box1 [:,None ,:2 ], box2 [:,:2 ]) # [N,M,2]
513- rb = torch .min (box1 [:,None ,2 :], box2 [:,2 :]) # [N,M,2]
511+ lt = torch .max (box1 [:, None , :2 ], box2 [:, :2 ]) # [N,M,2]
512+ rb = torch .min (box1 [:, None , 2 :], box2 [:, 2 :]) # [N,M,2]
514513
515- wh = (rb - lt + 1 ).clamp (min = 0 ) # [N,M,2]
516- inter = wh [:,:, 0 ] * wh [:,:, 1 ] # [N,M]
514+ wh = (rb - lt + 1 ).clamp (min = 0 ) # [N,M,2]
515+ inter = wh [:, :, 0 ] * wh [:, :, 1 ] # [N,M]
517516
518- area1 = (box1 [:,2 ] - box1 [:,0 ] + 1 ) * (box1 [:,3 ] - box1 [:,1 ] + 1 ) # [N,]
519- area2 = (box2 [:,2 ] - box2 [:,0 ] + 1 ) * (box2 [:,3 ] - box2 [:,1 ] + 1 ) # [M,]
520- iou = inter / (area1 [:,None ] + area2 - inter )
517+ area1 = (box1 [:, 2 ] - box1 [:, 0 ] + 1 ) * (box1 [:, 3 ] - box1 [:, 1 ] + 1 ) # [N,]
518+ area2 = (box2 [:, 2 ] - box2 [:, 0 ] + 1 ) * (box2 [:, 3 ] - box2 [:, 1 ] + 1 ) # [M,]
519+ iou = inter / (area1 [:, None ] + area2 - inter )
521520 return iou
522521
523522
@@ -537,12 +536,12 @@ def box_nms(bboxes, scores, threshold=0.5, mode='union'):
537536 Reference:
538537 https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
539538 """
540- x1 = bboxes [:,0 ]
541- y1 = bboxes [:,1 ]
542- x2 = bboxes [:,2 ]
543- y2 = bboxes [:,3 ]
539+ x1 = bboxes [:, 0 ]
540+ y1 = bboxes [:, 1 ]
541+ x2 = bboxes [:, 2 ]
542+ y2 = bboxes [:, 3 ]
544543
545- areas = (x2 - x1 + 1 ) * (y2 - y1 + 1 )
544+ areas = (x2 - x1 + 1 ) * (y2 - y1 + 1 )
546545 _ , order = scores .sort (0 , descending = True )
547546
548547 keep = []
@@ -558,9 +557,9 @@ def box_nms(bboxes, scores, threshold=0.5, mode='union'):
558557 xx2 = x2 [order [1 :]].clamp (max = x2 [i ])
559558 yy2 = y2 [order [1 :]].clamp (max = y2 [i ])
560559
561- w = (xx2 - xx1 + 1 ).clamp (min = 0 )
562- h = (yy2 - yy1 + 1 ).clamp (min = 0 )
563- inter = w * h
560+ w = (xx2 - xx1 + 1 ).clamp (min = 0 )
561+ h = (yy2 - yy1 + 1 ).clamp (min = 0 )
562+ inter = w * h
564563
565564 if mode == 'union' :
566565 ovr = inter / (areas [i ] + areas [order [1 :]] - inter )
@@ -569,8 +568,8 @@ def box_nms(bboxes, scores, threshold=0.5, mode='union'):
569568 else :
570569 raise TypeError ('Unknown nms mode: %s.' % mode )
571570
572- ids = (ovr <= threshold ).nonzero ().squeeze ()
571+ ids = (ovr <= threshold ).nonzero ().squeeze ()
573572 if ids .numel () == 0 :
574573 break
575- order = order [ids + 1 ]
574+ order = order [ids + 1 ]
576575 return torch .LongTensor (keep )
0 commit comments