Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
split_video_files_by_view
=========================

.. currentmodule:: lightning_pose.utils.io

.. autofunction:: split_video_files_by_view
15 changes: 0 additions & 15 deletions lightning_pose/data/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# to ignore imports for sphix-autoapidoc
__all__ = [
"project_camera_pairs_to_3d",
"get_valid_projection_masks",
"CameraGroup",
]

Expand Down Expand Up @@ -81,20 +80,6 @@ def project_camera_pairs_to_3d(
return torch.stack(p3d, dim=1)


def get_valid_projection_masks(
    points: "TensorType['batch', 'num_views', 'num_keypoints', 2]"
) -> "TensorType['batch', 'cam_pair', 'num_keypoints']":
    """Compute per-camera-pair validity masks for 2D keypoints.

    A keypoint is considered valid for a camera pair only when it is finite
    (no NaN in either coordinate) in both views of the pair.

    Args:
        points: 2D keypoints for every view; NaNs mark missing detections.

    Returns:
        Boolean tensor of shape (batch, C(num_views, 2), num_keypoints);
        pair ordering follows ``itertools.combinations(range(num_views), 2)``.
    """
    num_views = points.shape[1]

    # valid per view: neither x nor y is NaN (checking only the x-coordinate
    # would let a keypoint with a NaN y-coordinate slip through as valid)
    view_valid = ~torch.isnan(points).any(dim=-1)

    m3d = []
    for j1, j2 in itertools.combinations(range(num_views), 2):
        # a pair projection is valid only if both member views are valid
        m3d.append(view_valid[:, j1] & view_valid[:, j2])
    return torch.stack(m3d, dim=1)


class CameraGroup(CameraGroupAnipose):
"""Inherit Anipose camera group and add new non-jitted triangulation method for dataloaders."""

Expand Down
29 changes: 12 additions & 17 deletions lightning_pose/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,12 @@ def __init__(self, log_weight: float = 0.0, **kwargs) -> None:
def remove_nans(
self,
loss: TensorType["batch", "cam_pairs", "num_keypoints"],
mask: TensorType["batch", "cam_pairs", "num_keypoints"],
) -> TensorType["valid_losses"]:
return torch.masked_select(loss, mask)
mask = ~torch.isnan(loss)
if mask.sum() == 0.0:
return torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
else:
return torch.masked_select(loss, ~torch.isnan(loss))

def compute_loss(
self,
Expand All @@ -790,32 +793,24 @@ def __call__(
self,
keypoints_targ_3d: TensorType["batch", "num_keypoints", 3],
keypoints_pred_3d: TensorType["batch", "cam_pairs", "num_keypoints", 3],
keypoints_mask_3d: TensorType["batch", "cam_pairs", "num_keypoints"],
stage: Literal["train", "val", "test"] | None = None,
**kwargs,
) -> Tuple[TensorType[()], list[dict]]:

# check if 3D keypoints are available
if keypoints_targ_3d is None or keypoints_pred_3d is None or keypoints_mask_3d is None:
if keypoints_targ_3d is None or keypoints_pred_3d is None:
raise ValueError(
f"3D keypoints not available for {stage} stage. "
"Camera params file is required but not found;"
"Turn off supervised_pairwise_projections loss to avoid this error."
)

if keypoints_mask_3d.sum() == 0:
scalar_loss = torch.tensor(
0.0,
device=keypoints_targ_3d.device,
dtype=keypoints_targ_3d.dtype,
)
else:
elementwise_loss = self.compute_loss(
targets=keypoints_targ_3d,
predictions=keypoints_pred_3d,
)
clean_loss = self.remove_nans(loss=elementwise_loss, mask=keypoints_mask_3d)
scalar_loss = self.reduce_loss(clean_loss, method="mean")
elementwise_loss = self.compute_loss(
targets=keypoints_targ_3d,
predictions=keypoints_pred_3d,
)
clean_loss = self.remove_nans(loss=elementwise_loss)
scalar_loss = self.reduce_loss(clean_loss, method="mean")

logs = self.log_loss(loss=scalar_loss, stage=stage)

Expand Down
8 changes: 1 addition & 7 deletions lightning_pose/models/heatmap_tracker_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import nn
from torchtyping import TensorType

from lightning_pose.data.cameras import get_valid_projection_masks, project_camera_pairs_to_3d
from lightning_pose.data.cameras import project_camera_pairs_to_3d
from lightning_pose.data.datatypes import (
MultiviewHeatmapLabeledBatchDict,
MultiviewUnlabeledBatchDict,
Expand Down Expand Up @@ -266,19 +266,14 @@ def get_loss_inputs_labeled(
dist=batch_dict["distortions"].float(),
)
keypoints_targ_3d = batch_dict["keypoints_3d"]
keypoints_mask_3d = get_valid_projection_masks(
target_keypoints.reshape((-1, num_views, num_keypoints, 2))
)

except Exception as e:
print(f"Error in 3D projection: {e}")
keypoints_pred_3d = None
keypoints_targ_3d = None
keypoints_mask_3d = None
else:
keypoints_pred_3d = None
keypoints_targ_3d = None
keypoints_mask_3d = None

return {
"heatmaps_targ": batch_dict["heatmaps"],
Expand All @@ -288,7 +283,6 @@ def get_loss_inputs_labeled(
"confidences": confidence,
"keypoints_targ_3d": keypoints_targ_3d, # shape (2*batch, num_keypoints, 3)
"keypoints_pred_3d": keypoints_pred_3d, # shape (2*batch, cam_pairs, num_keypoints, 3)
"keypoints_mask_3d": keypoints_mask_3d, # shape (2*batch, cam_pairs, num_keypoints)
}

def predict_step(
Expand Down
40 changes: 1 addition & 39 deletions tests/data/test_cameras.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import torch

from lightning_pose.data.cameras import (
get_valid_projection_masks,
project_camera_pairs_to_3d,
)
from lightning_pose.data.cameras import project_camera_pairs_to_3d


def test_project_camera_pairs_to_3d():
Expand Down Expand Up @@ -92,38 +89,3 @@ def test_project_camera_pairs_to_3d():
assert torch.all(torch.isnan(p3d[0, 1, 0, :]))
assert torch.allclose(p3d[0, 1, 1, :], target[0, 1, 1, :], rtol=1e-2)
assert torch.allclose(p3d[0, 2], target[0, 2], rtol=1e-3)


def test_get_valid_projection_masks():

    batch_size = 2
    view_count = 3
    keypoint_count = 4
    points = torch.randn((batch_size, view_count, keypoint_count, 2))

    # knock out a few keypoints with NaNs
    points[0, 0, 0, :] = float('nan')  # view 0 -> hits pairs (0,1) and (0,2)
    points[0, 0, 1, :] = float('nan')  # view 0 -> hits pairs (0,1) and (0,2)
    points[1, 2, 3, :] = float('nan')  # view 2 -> hits pairs (0,2) and (1,2)

    masks = get_valid_projection_masks(points)

    # pair axis has C(3, 2) = 3 entries, ordered (0,1), (0,2), (1,2)
    assert masks.shape == (batch_size, 3, keypoint_count)

    # every (batch, pair, keypoint) entry expected to be masked out
    invalid = {
        (0, 0, 0), (0, 1, 0),  # from NaN at batch 0, view 0, keypoint 0
        (0, 0, 1), (0, 1, 1),  # from NaN at batch 0, view 0, keypoint 1
        (1, 1, 3), (1, 2, 3),  # from NaN at batch 1, view 2, keypoint 3
    }
    for b in range(batch_size):
        for p in range(3):
            for k in range(keypoint_count):
                assert bool(masks[b, p, k]) == ((b, p, k) not in invalid)
54 changes: 37 additions & 17 deletions tests/losses/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,7 @@ def test_zero_loss_when_predictions_equal_targets(self, pp_loss):
num_cam_pairs = 3
keypoints_targ_3d = torch.ones(size=(num_batch, num_keypoints, 3))
keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints)).bool()
loss, logs = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d, stage=stage)
loss, logs = pp_loss(keypoints_targ_3d, keypoints_pred_3d, stage=stage)
assert loss.shape == torch.Size([])
assert loss == 0.0
assert logs[0]["name"] == f"{stage}_pairwise_projections_loss"
Expand All @@ -501,32 +500,53 @@ def test_actual_values(self, pp_loss):
num_cam_pairs = 3
keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints)).bool()
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d)
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
assert loss.isclose(torch.sqrt(torch.tensor(3)))

def test_mask(self, pp_loss):
num_batch = 2
def test_targets_all_nans(self, pp_loss):
num_batch = 1
num_keypoints = 4
num_cam_pairs = 3
keypoints_targ_3d = torch.zeros(size=(num_batch, num_keypoints, 3))
keypoints_pred_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints, 3))
keypoints_pred_3d[-1, ...] = 2
keypoints_mask_3d = torch.ones(size=(num_batch, num_cam_pairs, num_keypoints))
keypoints_mask_3d[-1, ...] = 0
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d.bool())
assert loss.isclose(torch.sqrt(torch.tensor(3)))
keypoints_targ_3d = torch.full((num_batch, num_keypoints, 3), float('nan'))
keypoints_pred_3d = torch.ones((num_batch, num_cam_pairs, num_keypoints, 3))
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
assert loss.item() == 0.0

def test_all_nans(self, pp_loss):
def test_predictions_all_nans(self, pp_loss):
num_batch = 1
num_keypoints = 4
num_cam_pairs = 3
keypoints_targ_3d = torch.full((num_batch, num_keypoints, 3), float('nan'))
keypoints_targ_3d = torch.ones((num_batch, num_keypoints, 3))
keypoints_pred_3d = torch.full((num_batch, num_cam_pairs, num_keypoints, 3), float('nan'))
keypoints_mask_3d = torch.zeros(size=(num_batch, num_cam_pairs, num_keypoints))
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d, keypoints_mask_3d.bool())
loss, _ = pp_loss(keypoints_targ_3d, keypoints_pred_3d)
assert loss.item() == 0.0

def test_targets_partial_nans(self, pp_loss):
    batch, n_pairs, n_kps = 2, 2, 4
    targets = torch.zeros((batch, n_kps, 3))
    targets[0, 0] = float('nan')  # drop first keypoint of first batch item
    predictions = torch.ones((batch, n_pairs, n_kps, 3))
    loss, _ = pp_loss(targets, predictions)
    # every surviving element contributes sqrt(3): the 0 -> 1 distance in 3D
    assert loss.isclose(torch.sqrt(torch.tensor(3.0)))

def test_predictions_partial_nans(self, pp_loss):
    batch, n_pairs, n_kps = 3, 3, 4
    targets = torch.zeros((batch, n_kps, 3))
    predictions = torch.ones((batch, n_pairs, n_kps, 3))
    # NaN-out a single keypoint, a whole camera-pair row, and a full batch item
    predictions[0, 0, 0] = float('nan')
    predictions[1, 1] = float('nan')
    predictions[2] = float('nan')
    loss, _ = pp_loss(targets, predictions)
    # remaining valid entries all sit sqrt(3) away from their targets
    assert loss.isclose(torch.sqrt(torch.tensor(3.0)))


def test_get_loss_classes():
loss_classes = get_loss_classes()
Expand Down