@@ -508,30 +508,45 @@ def test_targets_all_nans(self, pp_loss):
508508 num_keypoints = 4
509509 num_cam_pairs = 3
510510 keypoints_targ_3d = torch .full ((num_batch , num_keypoints , 3 ), float ('nan' ))
511- keypoints_pred_3d = torch .ones ((num_batch , num_cam_pairs , num_keypoints , 3 ))
511+ keypoints_pred_3d = torch .ones (
512+ (num_batch , num_cam_pairs , num_keypoints , 3 ),
513+ requires_grad = True ,
514+ )
512515 loss , _ = pp_loss (keypoints_targ_3d , keypoints_pred_3d )
513516 assert loss .item () == 0.0
517+ loss .backward ()
518+ assert not torch .isnan (keypoints_pred_3d .grad ).any (), "gradients contain NaN values"
514519
515520 def test_predictions_all_nans (self , pp_loss ):
516521 num_batch = 1
517522 num_keypoints = 4
518523 num_cam_pairs = 3
519524 keypoints_targ_3d = torch .ones ((num_batch , num_keypoints , 3 ))
520- keypoints_pred_3d = torch .full ((num_batch , num_cam_pairs , num_keypoints , 3 ), float ('nan' ))
525+ keypoints_pred_3d = torch .full (
526+ (num_batch , num_cam_pairs , num_keypoints , 3 ), float ('nan' ),
527+ requires_grad = True ,
528+ )
521529 loss , _ = pp_loss (keypoints_targ_3d , keypoints_pred_3d )
522530 assert loss .item () == 0.0
531+ loss .backward ()
532+ assert not torch .isnan (keypoints_pred_3d .grad ).any (), "gradients contain NaN values"
523533
524534 def test_targets_partial_nans (self , pp_loss ):
525535 num_batch = 2
526536 num_keypoints = 4
527537 num_cam_pairs = 2
528538 keypoints_targ_3d = torch .zeros (size = (num_batch , num_keypoints , 3 ))
529539 keypoints_targ_3d [0 , 0 , :] = float ('nan' ) # first keypoint in first batch NaN
530- keypoints_pred_3d = torch .ones (size = (num_batch , num_cam_pairs , num_keypoints , 3 ))
540+ keypoints_pred_3d = torch .ones (
541+ size = (num_batch , num_cam_pairs , num_keypoints , 3 ),
542+ requires_grad = True ,
543+ )
531544 loss , _ = pp_loss (keypoints_targ_3d , keypoints_pred_3d )
532545 # each valid position has loss = sqrt(3) (distance from 0 to 1 in 3D)
533546 expected_loss = torch .sqrt (torch .tensor (3.0 ))
534547 assert loss .isclose (expected_loss )
548+ loss .backward ()
549+ assert not torch .isnan (keypoints_pred_3d .grad ).any (), "gradients contain NaN values"
535550
536551 def test_predictions_partial_nans (self , pp_loss ):
537552 num_batch = 3
@@ -542,10 +557,13 @@ def test_predictions_partial_nans(self, pp_loss):
542557 keypoints_pred_3d [0 , 0 , 0 , :] = float ('nan' )
543558 keypoints_pred_3d [1 , 1 , :, :] = float ('nan' )
544559 keypoints_pred_3d [2 , :, :, :] = float ('nan' )
560+ keypoints_pred_3d .requires_grad_ (True ) # need to do this after inplace operations
545561 loss , _ = pp_loss (keypoints_targ_3d , keypoints_pred_3d )
546562 # each valid position has loss = sqrt(3) (distance from 0 to 1 in 3D)
547563 expected_loss = torch .sqrt (torch .tensor (3.0 ))
548564 assert loss .isclose (expected_loss )
565+ loss .backward ()
566+ assert not torch .isnan (keypoints_pred_3d .grad ).any (), "gradients contain NaN values"
549567
550568
551569def test_get_loss_classes ():
0 commit comments