Skip to content

Commit 1dce647

Browse files
fix NaN gradients in 3d triangulation loss (#358)
1 parent 0b4da26 commit 1dce647

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

lightning_pose/losses/losses.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,17 +776,55 @@ def remove_nans(
776776
loss: TensorType["batch", "cam_pairs", "num_keypoints"],
777777
) -> TensorType["valid_losses"]:
778778
mask = ~torch.isnan(loss)
779-
if mask.sum() == 0.0:
780-
return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
779+
valid_losses = torch.masked_select(loss, mask)
780+
if valid_losses.numel() == 0:
781+
# No valid losses, return zero that preserves gradients
782+
# Use torch.where to avoid nan*0.0 issues
783+
dummy_loss = torch.where(mask, loss, torch.zeros_like(loss))
784+
return dummy_loss.sum() # This will be 0.0 and preserve gradients
781785
else:
782-
return torch.masked_select(loss, ~torch.isnan(loss))
786+
return valid_losses
783787

784788
def compute_loss(
    self,
    targets: TensorType["batch", "num_keypoints", 3],
    predictions: TensorType["batch", "cam_pairs", "num_keypoints", 3],
) -> TensorType["batch", "cam_pairs", "num_keypoints"]:
    """Return the per-keypoint euclidean distance between targets and predictions.

    Entries where either tensor holds a NaN coordinate come back as NaN so the
    caller can mask them out. Crucially, the NaN coordinates are replaced with
    detached zeros *before* the norm is taken: feeding NaNs (or differencing a
    NaN against anything) into ``torch.linalg.norm`` would poison the backward
    pass with NaN gradients even at positions that are later masked away.
    """
    # Boolean masks flagging keypoints that contain any NaN coordinate.
    target_is_nan = torch.isnan(targets).any(dim=-1)       # [batch, num_keypoints]
    pred_is_nan = torch.isnan(predictions).any(dim=-1)     # [batch, cam_pairs, num_keypoints]

    # A position is invalid if its target OR its prediction is NaN.
    invalid = target_is_nan.unsqueeze(1) | pred_is_nan     # [batch, cam_pairs, num_keypoints]

    # Swap NaN coordinates for detached zeros so the norm below is NaN-free.
    safe_targets = torch.where(
        target_is_nan.unsqueeze(-1),                       # [batch, num_keypoints, 1]
        torch.zeros_like(targets).detach(),
        targets,
    )
    safe_predictions = torch.where(
        invalid.unsqueeze(-1),                             # [batch, cam_pairs, num_keypoints, 1]
        torch.zeros_like(predictions).detach(),
        predictions,
    )

    # Euclidean distance per (batch, cam_pair, keypoint).
    dists = torch.linalg.norm(
        safe_targets.unsqueeze(1) - safe_predictions, ord=2, dim=-1,
    )

    # Re-mark the originally-invalid positions as NaN for downstream masking.
    nan_fill = torch.tensor(float('nan'), device=dists.device, dtype=dists.dtype)
    loss = torch.where(invalid, nan_fill, dists)
    return loss
791829

792830
def __call__(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[project]
66
name = "lightning-pose"
7-
version = "2.0.3"
7+
version = "2.0.4"
88
description = "Semi-supervised pose estimation using pytorch lightning"
99
license = "MIT"
1010
readme = "README.md"

tests/losses/test_losses.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,30 +508,45 @@ def test_targets_all_nans(self, pp_loss):
508508
num_keypoints = 4
509509
num_cam_pairs = 3
510510
keypoints_targ_3d = torch.full((num_batch, num_keypoints, 3), float('nan'))
511-
keypoints_pred_3d = torch.ones((num_batch, num_cam_pairs, num_keypoints, 3))
511+
keypoints_pred_3d = torch.ones(
512+
(num_batch, num_cam_pairs, num_keypoints, 3),
513+
requires_grad=True,
514+
)
512515
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
513516
assert loss.item() == 0.0
517+
loss.backward()
518+
assert not torch.isnan(keypoints_pred_3d.grad).any(), "gradients contain NaN values"
514519

515520
def test_predictions_all_nans(self, pp_loss):
    """All-NaN predictions must give a zero loss and a NaN-free gradient."""
    # shapes: batch=1, cam_pairs=3, keypoints=4, coords=3
    targ = torch.ones((1, 4, 3))
    pred = torch.full((1, 3, 4, 3), float('nan'), requires_grad=True)
    loss, _ = pp_loss(targ, pred)
    assert loss.item() == 0.0
    loss.backward()
    assert not torch.isnan(pred.grad).any(), "gradients contain NaN values"
523533

524534
def test_targets_partial_nans(self, pp_loss):
    """One NaN target keypoint is skipped; the rest contribute sqrt(3) each."""
    # shapes: batch=2, cam_pairs=2, keypoints=4, coords=3
    targ = torch.zeros(size=(2, 4, 3))
    targ[0, 0, :] = float('nan')  # invalidate the first keypoint of the first sample
    pred = torch.ones(size=(2, 2, 4, 3), requires_grad=True)
    loss, _ = pp_loss(targ, pred)
    # every valid entry is the 3d distance between the zero and all-ones vectors
    assert loss.isclose(torch.tensor(3.0).sqrt())
    loss.backward()
    assert not torch.isnan(pred.grad).any(), "gradients contain NaN values"
535550

536551
def test_predictions_partial_nans(self, pp_loss):
537552
num_batch = 3
@@ -542,10 +557,13 @@ def test_predictions_partial_nans(self, pp_loss):
542557
keypoints_pred_3d[0, 0, 0, :] = float('nan')
543558
keypoints_pred_3d[1, 1, :, :] = float('nan')
544559
keypoints_pred_3d[2, :, :, :] = float('nan')
560+
keypoints_pred_3d.requires_grad_(True) # need to do this after inplace operations
545561
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
546562
# each valid position has loss = sqrt(3) (distance from 0 to 1 in 3D)
547563
expected_loss = torch.sqrt(torch.tensor(3.0))
548564
assert loss.isclose(expected_loss)
565+
loss.backward()
566+
assert not torch.isnan(keypoints_pred_3d.grad).any(), "gradients contain NaN values"
549567

550568

551569
def test_get_loss_classes():

0 commit comments

Comments
 (0)