diff --git a/docs/api/lightning_pose.utils.io.split_video_files_by_view.rst b/docs/api/lightning_pose.utils.io.split_video_files_by_view.rst
new file mode 100644
index 00000000..94261cbc
--- /dev/null
+++ b/docs/api/lightning_pose.utils.io.split_video_files_by_view.rst
@@ -0,0 +1,6 @@
+split_video_files_by_view
+=========================
+
+.. currentmodule:: lightning_pose.utils.io
+
+.. autofunction:: split_video_files_by_view
diff --git a/lightning_pose/data/cameras.py b/lightning_pose/data/cameras.py
index 8a7f0ac8..51bf85fa 100644
--- a/lightning_pose/data/cameras.py
+++ b/lightning_pose/data/cameras.py
@@ -12,7 +12,6 @@
 # to ignore imports for sphix-autoapidoc
 __all__ = [
     "project_camera_pairs_to_3d",
-    "get_valid_projection_masks",
     "CameraGroup",
 ]

@@ -81,20 +80,6 @@ def project_camera_pairs_to_3d(
     return torch.stack(p3d, dim=1)


-def get_valid_projection_masks(
-    points: TensorType["batch", "num_views", "num_keypoints", 2]
-) -> TensorType["batch", "cam_pair", "num_keypoints"]:
-
-    num_batch, num_views, num_keypoints, _ = points.shape
-
-    m3d = []
-    for j1, j2 in itertools.combinations(range(num_views), 2):
-        points1 = points[:, j1, :, 0]
-        points2 = points[:, j2, :, 0]
-        m3d.append(~torch.isnan(points1 + points2))
-    return torch.stack(m3d, dim=1)
-
-
 class CameraGroup(CameraGroupAnipose):
     """Inherit Anipose camera group and add new non-jitted triangulation method for dataloaders."""

diff --git a/lightning_pose/losses/losses.py b/lightning_pose/losses/losses.py
index 863b4b2c..f11c1b43 100644
--- a/lightning_pose/losses/losses.py
+++ b/lightning_pose/losses/losses.py
@@ -774,9 +774,12 @@ def __init__(self, log_weight: float = 0.0, **kwargs) -> None:
     def remove_nans(
         self,
         loss: TensorType["batch", "cam_pairs", "num_keypoints"],
-        mask: TensorType["batch", "cam_pairs", "num_keypoints"],
     ) -> TensorType["valid_losses"]:
-        return torch.masked_select(loss, mask)
+        mask = ~torch.isnan(loss)
+        if mask.sum() == 0.0:
+            return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+        else:
+            return torch.masked_select(loss, ~torch.isnan(loss))

     def compute_loss(
         self,
@@ -790,32 +793,24 @@
     def __call__(
         self,
         keypoints_targ_3d: TensorType["batch", "num_keypoints", 3],
         keypoints_pred_3d: TensorType["batch", "cam_pairs", "num_keypoints", 3],
-        keypoints_mask_3d: TensorType["batch", "cam_pairs", "num_keypoints"],
         stage: Literal["train", "val", "test"] | None = None,
         **kwargs,
     ) -> Tuple[TensorType[()], list[dict]]:

         # check if 3D keypoints are available
-        if keypoints_targ_3d is None or keypoints_pred_3d is None or keypoints_mask_3d is None:
+        if keypoints_targ_3d is None or keypoints_pred_3d is None:
             raise ValueError(
                 f"3D keypoints not available for {stage} stage. "
                 "Camera params file is required but not found;"
                 "Turn off supervised_pairwise_projections loss to avoid this error."
             )

-        if keypoints_mask_3d.sum() == 0:
-            scalar_loss = torch.tensor(
-                0.0,
-                device=keypoints_targ_3d.device,
-                dtype=keypoints_targ_3d.dtype,
-            )
-        else:
-            elementwise_loss = self.compute_loss(
-                targets=keypoints_targ_3d,
-                predictions=keypoints_pred_3d,
-            )
-            clean_loss = self.remove_nans(loss=elementwise_loss, mask=keypoints_mask_3d)
-            scalar_loss = self.reduce_loss(clean_loss, method="mean")
+        elementwise_loss = self.compute_loss(
+            targets=keypoints_targ_3d,
+            predictions=keypoints_pred_3d,
+        )
+        clean_loss = self.remove_nans(loss=elementwise_loss)
+        scalar_loss = self.reduce_loss(clean_loss, method="mean")

         logs = self.log_loss(loss=scalar_loss, stage=stage)
diff --git a/lightning_pose/models/heatmap_tracker_multiview.py b/lightning_pose/models/heatmap_tracker_multiview.py
index b7e12849..bd35a611 100644
--- a/lightning_pose/models/heatmap_tracker_multiview.py
+++ b/lightning_pose/models/heatmap_tracker_multiview.py
@@ -8,7 +8,7 @@
 from torch import nn
 from torchtyping import TensorType

-from lightning_pose.data.cameras import get_valid_projection_masks, project_camera_pairs_to_3d
+from lightning_pose.data.cameras import project_camera_pairs_to_3d
 from lightning_pose.data.datatypes import (
     MultiviewHeatmapLabeledBatchDict,
     MultiviewUnlabeledBatchDict,
@@ -266,19 +266,14 @@ def get_loss_inputs_labeled(
                     dist=batch_dict["distortions"].float(),
                 )
                 keypoints_targ_3d = batch_dict["keypoints_3d"]
-                keypoints_mask_3d = get_valid_projection_masks(
-                    target_keypoints.reshape((-1, num_views, num_keypoints, 2))
-                )
             except Exception as e:
                 print(f"Error in 3D projection: {e}")
                 keypoints_pred_3d = None
                 keypoints_targ_3d = None
-                keypoints_mask_3d = None
         else:
             keypoints_pred_3d = None
             keypoints_targ_3d = None
-            keypoints_mask_3d = None

         return {
             "heatmaps_targ": batch_dict["heatmaps"],
@@ -288,7 +283,6 @@ def get_loss_inputs_labeled(
             "confidences": confidence,
             "keypoints_targ_3d": keypoints_targ_3d,  # shape (2*batch, num_keypoints, 3)
             "keypoints_pred_3d": keypoints_pred_3d,  # shape (2*batch, cam_pairs, num_keypoints, 3)
-            "keypoints_mask_3d": keypoints_mask_3d,  # shape (2*batch, cam_pairs, num_keypoints)
         }

     def predict_step(
diff --git a/tests/data/test_cameras.py b/tests/data/test_cameras.py
index ebacefd1..9eb94a80 100644
--- a/tests/data/test_cameras.py
+++ b/tests/data/test_cameras.py
@@ -1,9 +1,6 @@
 import torch

-from lightning_pose.data.cameras import (
-    get_valid_projection_masks,
-    project_camera_pairs_to_3d,
-)
+from lightning_pose.data.cameras import project_camera_pairs_to_3d


 def test_project_camera_pairs_to_3d():
@@ -92,38 +89,3 @@ def test_project_camera_pairs_to_3d():
     assert torch.all(torch.isnan(p3d[0, 1, 0, :]))
     assert torch.allclose(p3d[0, 1, 1, :], target[0, 1, 1, :], rtol=1e-2)
     assert torch.allclose(p3d[0, 2], target[0, 2], rtol=1e-3)
-
-
-def test_get_valid_projection_masks():
-
-    n_batch = 2
-    n_views = 3
-    n_keypoints = 4
-    points = torch.randn((n_batch, n_views, n_keypoints, 2))
-
-    points[0, 0, 0, :] = float('nan')  # nan1
-    points[0, 0, 1, :] = float('nan')  # nan2
-    points[1, 2, 3, :] = float('nan')  # nan3
-
-    masks = get_valid_projection_masks(points)
-
-    assert masks.shape == (n_batch, 3, n_keypoints)  # 3 = 3 choose 2
-
-    # effect of nan1
-    assert ~masks[0, 0, 0]
-    masks[0, 0, 0] = True
-    assert ~masks[0, 1, 0]
-    masks[0, 1, 0] = True
-    # effect of nan2
-    assert ~masks[0, 0, 1]
-    masks[0, 0, 1] = True
-    assert ~masks[0, 1, 1]
-    masks[0, 1, 1] = True
-    # effect of nan3
-    assert ~masks[1, 1, 3]
-    masks[1, 1, 3] = True
-    assert ~masks[1, 2, 3]
-    masks[1, 2, 3] = True
-
-    # test others
-    assert torch.all(masks)
diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py
index 6292dc36..b56c82e6 100644
--- a/tests/losses/test_losses.py
+++ b/tests/losses/test_losses.py
@@ -486,8 +486,7 @@ def test_zero_loss_when_predictions_equal_targets(self, pp_loss):
         num_cam_pairs = 3
         keypoints_targ_3d = torch.ones(size=(num_batch, num_keypoints, 3))
         keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
-        keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints)).bool()
-        loss, logs = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d, stage=stage)
+        loss, logs = pp_loss(keypoints_targ_3d, keypoints_pred_3d, stage=stage)
         assert loss.shape == torch.Size([])
         assert loss == 0.0
         assert logs[0]["name"] == f"{stage}_pairwise_projections_loss"
@@ -501,32 +500,53 @@ def test_actual_values(self, pp_loss):
         num_cam_pairs = 3
         keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
         keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
-        keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints)).bool()
-        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d)
+        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
         assert loss.isclose(torch.sqrt(torch.tensor(3)))

-    def test_mask(self, pp_loss):
-        num_batch = 2
+    def test_targets_all_nans(self, pp_loss):
+        num_batch = 1
         num_keypoints = 4
         num_cam_pairs = 3
-        keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
-        keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
-        keypoints_pred_3d[-1, ...] = 2
-        keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints))
-        keypoints_mask_3d[-1, ...] = 0
-        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d.bool())
-        assert loss.isclose(torch.sqrt(torch.tensor(3)))
+        keypoints_targ_3d = torch.full((num_batch, num_keypoints, 3), float('nan'))
+        keypoints_pred_3d = torch.ones((num_batch, num_cam_pairs, num_keypoints, 3))
+        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
+        assert loss.item() == 0.0

-    def test_all_nans(self, pp_loss):
+    def test_predictions_all_nans(self, pp_loss):
         num_batch = 1
         num_keypoints = 4
         num_cam_pairs = 3
-        keypoints_targ_3d = torch.full((num_batch, num_keypoints, 3), float('nan'))
+        keypoints_targ_3d = torch.ones((num_batch, num_keypoints, 3))
         keypoints_pred_3d = torch.full((num_batch, num_cam_pairs, num_keypoints, 3), float('nan'))
-        keypoints_mask_3d = torch.zeros(size=(num_batch, num_cam_pairs, num_keypoints))
-        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d.bool())
+        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
         assert loss.item() == 0.0

+    def test_targets_partial_nans(self, pp_loss):
+        num_batch = 2
+        num_keypoints = 4
+        num_cam_pairs = 2
+        keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
+        keypoints_targ_3d[0, 0, :] = float('nan')  # first keypoint in first batch NaN
+        keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
+        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
+        # each valid position has loss = sqrt(3) (distance from 0 to 1 in 3D)
+        expected_loss = torch.sqrt(torch.tensor(3.0))
+        assert loss.isclose(expected_loss)
+
+    def test_predictions_partial_nans(self, pp_loss):
+        num_batch = 3
+        num_keypoints = 4
+        num_cam_pairs = 3
+        keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
+        keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
+        keypoints_pred_3d[0, 0, 0, :] = float('nan')
+        keypoints_pred_3d[1, 1, :, :] = float('nan')
+        keypoints_pred_3d[2, :, :, :] = float('nan')
+        loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
+        # each valid position has loss = sqrt(3) (distance from 0 to 1 in 3D)
+        expected_loss = torch.sqrt(torch.tensor(3.0))
+        assert loss.isclose(expected_loss)
+


 def test_get_loss_classes():
     loss_classes = get_loss_classes()
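Below is a minimal standalone sketch (not part of this diff) illustrating the behavior the patch introduces: the validity mask is now derived from NaNs in the elementwise loss itself, and the loss falls back to zero when every entry is NaN. The helper name nan_masked_mean is hypothetical; it simply combines the new remove_nans logic with the subsequent mean reduction.

# Illustrative sketch only, assuming an elementwise loss of shape
# (batch, cam_pairs, num_keypoints) that contains NaNs wherever either the
# 3D targets or the pairwise 3D projections were NaN.
import torch


def nan_masked_mean(loss: torch.Tensor) -> torch.Tensor:
    """Average the non-NaN entries of `loss`; return 0 if every entry is NaN."""
    mask = ~torch.isnan(loss)
    if mask.sum() == 0:
        return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
    return torch.masked_select(loss, mask).mean()


# Example: two of six entries are NaN; only the valid ones contribute.
loss = torch.tensor([[[1.0, float("nan"), 3.0], [5.0, 7.0, float("nan")]]])
print(nan_masked_mean(loss))  # tensor(4.) -> mean of 1, 3, 5, 7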