@@ -316,12 +316,12 @@ def compile(
316316 }
317317 super ().compile (loss = losses , ** kwargs )
318318
319- def compute_loss (self , images , gt_boxes , gt_classes , training ):
319+ def compute_loss (self , images , boxes , classes , training ):
320320 box_pred , cls_pred = self ._forward (images , training = training )
321- if gt_boxes .shape [- 1 ] != 4 :
321+ if boxes .shape [- 1 ] != 4 :
322322 raise ValueError (
323- "gt_boxes should have shape (None, None, 4). Got "
324- f"gt_boxes .shape={ tuple (gt_boxes .shape )} "
323+ "boxes should have shape (None, None, 4). Got "
324+ f"boxes .shape={ tuple (boxes .shape )} "
325325 )
326326
327327 if box_pred .shape [- 1 ] != 4 :
@@ -338,18 +338,18 @@ def compute_loss(self, images, gt_boxes, gt_classes, training):
338338 )
339339
340340 cls_labels = tf .one_hot (
341- tf .cast (gt_classes , dtype = tf .int32 ),
341+ tf .cast (classes , dtype = tf .int32 ),
342342 depth = self .classes ,
343343 dtype = tf .float32 ,
344344 )
345345
346- positive_mask = tf .cast (tf .greater (gt_classes , - 1.0 ), dtype = tf .float32 )
346+ positive_mask = tf .cast (tf .greater (classes , - 1.0 ), dtype = tf .float32 )
347347 normalizer = tf .reduce_sum (positive_mask )
348- cls_weights = tf .cast (tf .math .not_equal (gt_classes , - 2.0 ), dtype = tf .float32 )
348+ cls_weights = tf .cast (tf .math .not_equal (classes , - 2.0 ), dtype = tf .float32 )
349349 cls_weights /= normalizer
350350 box_weights = positive_mask / normalizer
351351 y_true = {
352- "box" : gt_boxes ,
352+ "box" : boxes ,
353353 "cls" : cls_labels ,
354354 }
355355 y_pred = {
@@ -372,16 +372,16 @@ def train_step(self, data):
372372 target = self .label_encoder .bounding_box_format ,
373373 images = x ,
374374 )
375- gt_boxes , gt_classes = self .label_encoder (x , y )
376- gt_boxes = bounding_box .convert_format (
377- gt_boxes ,
375+ boxes , classes = self .label_encoder (x , y )
376+ boxes = bounding_box .convert_format (
377+ boxes ,
378378 source = self .label_encoder .bounding_box_format ,
379379 target = self .bounding_box_format ,
380380 images = x ,
381381 )
382382
383383 with tf .GradientTape () as tape :
384- total_loss = self .compute_loss (x , gt_boxes , gt_classes , training = True )
384+ total_loss = self .compute_loss (x , boxes , classes , training = True )
385385
386386 reg_losses = []
387387 if self .weight_decay :
@@ -405,14 +405,14 @@ def test_step(self, data):
405405 target = self .label_encoder .bounding_box_format ,
406406 images = x ,
407407 )
408- gt_boxes , gt_classes = self .label_encoder (x , y )
409- gt_boxes = bounding_box .convert_format (
410- gt_boxes ,
408+ boxes , classes = self .label_encoder (x , y )
409+ boxes = bounding_box .convert_format (
410+ boxes ,
411411 source = self .label_encoder .bounding_box_format ,
412412 target = self .bounding_box_format ,
413413 images = x ,
414414 )
415- _ = self .compute_loss (x , gt_boxes , gt_classes , training = False )
415+ _ = self .compute_loss (x , boxes , classes , training = False )
416416
417417 return self .compute_metrics (x , {}, {}, sample_weight = {})
418418
0 commit comments