Merged

31 commits
6b35851  Refactor dataset validators (lochhh, Nov 14, 2025)
437a4c8  Remove duplication in subclasses (lochhh, Nov 24, 2025)
41ccacd  Simplify existing validator tests (lochhh, Nov 24, 2025)
5c14c87  Reorganise parent class (lochhh, Nov 24, 2025)
b5db555  Infer axes from DIM_NAMES (lochhh, Nov 24, 2025)
2ea7517  Add unit tests for parent class (lochhh, Nov 24, 2025)
98af9eb  Move LightningPose validation to position instead of individual_names (lochhh, Nov 25, 2025)
8f0958d  Remove ABC (lochhh, Nov 25, 2025)
074f060  Validate shape_array against position_array (lochhh, Nov 25, 2025)
7527c44  Add unit tests for frame_array (lochhh, Nov 25, 2025)
04ac17f  Group tests into classes (lochhh, Nov 25, 2025)
1204cb5  Replace test_datasets_validators.py (lochhh, Nov 25, 2025)
804f270  Address sonarqube warnings (lochhh, Nov 25, 2025)
9163b74  Add `validate` as _BaseDataset class method (lochhh, Nov 27, 2025)
06c57bb  Move LP check to load_poses.py and remove subclass hook (lochhh, Nov 28, 2025)
700f878  Rename dataset validator classes (lochhh, Dec 1, 2025)
c296b0b  Suggest DLC in from_lp_file error message (lochhh, Jan 6, 2026)
5b36178  Update validate method docstring (lochhh, Jan 6, 2026)
44ddcd3  Rename _BaseDatasetValidator to _BaseDatasetInputs (lochhh, Jan 6, 2026)
e686637  Rename bboxes validator class and add to_dataset method (lochhh, Jan 6, 2026)
6584608  Rename poses validator class and add to_dataset method (lochhh, Jan 6, 2026)
c543a74  Add abstract method to_dataset in _BaseDatasetInputs (lochhh, Jan 6, 2026)
25559ec  Update validate classmethod docstring (lochhh, Jan 6, 2026)
4792fc3  Update ValidInputs docstrings (lochhh, Jan 6, 2026)
a508c49  Update _validate_keypoint_names docstrings (lochhh, Jan 6, 2026)
84328b2  Simplify time_coords computation (lochhh, Jan 6, 2026)
99c7161  Add tests for to_dataset methods (lochhh, Jan 6, 2026)
282e76e  Use updated input validators in loader_widgets (lochhh, Jan 6, 2026)
8ae6721  Rename test variable in test_optional_fields_defaults (lochhh, Jan 6, 2026)
9d46a8d  Clarify warning message when converting str to list of str (lochhh, Jan 6, 2026)
50af275  Resolve arg-type and assignment errors (lochhh, Jan 8, 2026)
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -190,8 +190,8 @@ raise logger.exception(ValueError("message")) # with traceback
We aim to adhere to the [When to use logging guide](inv:python#logging-basic-tutorial) to ensure consistency in our logging practices.
In general:
* Use {func}`print` for simple, non-critical messages that do not need to be logged.
-* Use {func}`warnings.warn` for user input issues that are non-critical and can be addressed within `movement`, e.g. deprecated function calls that are redirected, invalid `fps` number in {class}`ValidPosesDataset<movement.validators.datasets.ValidPosesDataset>` that is implicitly set to `None`; or when processing data containing excessive NaNs, which the user can potentially address using appropriate methods, e.g. {func}`interpolate_over_time()<movement.filtering.interpolate_over_time>`
-* Use {meth}`logger.warning()<loguru._logger.Logger.warning>` for non-critical issues where default values are assigned to optional parameters, e.g. `individual_names`, `keypoint_names` in {class}`ValidPosesDataset<movement.validators.datasets.ValidPosesDataset>`.
+* Use {func}`warnings.warn` for user input issues that are non-critical and can be addressed within `movement`, e.g. deprecated function calls that are redirected, invalid `fps` number in {class}`ValidPosesInputs<movement.validators.datasets.ValidPosesInputs>` that is implicitly set to `None`; or when processing data containing excessive NaNs, which the user can potentially address using appropriate methods, e.g. {func}`interpolate_over_time()<movement.filtering.interpolate_over_time>`
+* Use {meth}`logger.warning()<loguru._logger.Logger.warning>` for non-critical issues where default values are assigned to optional parameters, e.g. `individual_names`, `keypoint_names` in {class}`ValidPosesInputs<movement.validators.datasets.ValidPosesInputs>`.

### Continuous integration
All pushes and pull requests will be built by [GitHub actions](github-docs:actions).
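The updated bullets distinguish user-facing warnings from internal default-assignment logs. A minimal sketch of that convention, with hypothetical helper names that are not part of movement's API:

```python
import warnings

from loguru import logger


def check_fps(fps: float | None) -> float | None:
    """Recoverable user-input issue: warn the user, then fall back to None."""
    if fps is not None and fps <= 0:
        warnings.warn(
            f"fps must be a positive number, got {fps}; setting fps to None.",
            UserWarning,
            stacklevel=2,
        )
        return None
    return fps


def default_individual_names(names: list[str] | None, n: int) -> list[str]:
    """Default assigned to an optional parameter: log it, no user action needed."""
    if names is None:
        logger.warning(f"individual_names not provided; using id_0 ... id_{n - 1}.")
        return [f"id_{i}" for i in range(n)]
    return names
```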
66 changes: 4 additions & 62 deletions movement/io/load_bboxes.py
@@ -11,7 +11,7 @@
import xarray as xr

from movement.utils.logging import logger
-from movement.validators.datasets import ValidBboxesDataset
+from movement.validators.datasets import ValidBboxesInputs
from movement.validators.files import (
DEFAULT_FRAME_REGEXP,
ValidFile,
@@ -136,7 +136,7 @@ def from_numpy(
... )

"""
-valid_bboxes_data = ValidBboxesDataset(
+valid_bboxes_inputs = ValidBboxesInputs(
position_array=position_array,
shape_array=shape_array,
confidence_array=confidence_array,
@@ -145,7 +145,7 @@
fps=fps,
source_software=source_software,
)
-return _ds_from_valid_data(valid_bboxes_data)
+return valid_bboxes_inputs.to_dataset()


def from_file(
@@ -360,7 +360,7 @@ def from_via_tracks_file(
),
fps=fps,
source_software="VIA-tracks",
-) # it validates the dataset via ValidBboxesDataset
+) # it validates the dataset via ValidBboxesInputs

# Add metadata as attributes
ds.attrs["source_software"] = "VIA-tracks"
@@ -648,61 +648,3 @@ def _via_attribute_column_to_numpy(
bbox_attr_array = np.array(list_bbox_attr)

return bbox_attr_array.squeeze()


-def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset:
-    """Convert a validated bounding boxes dataset to an xarray Dataset.
-
-    Parameters
-    ----------
-    data : movement.validators.datasets.ValidBboxesDataset
-        The validated bounding boxes dataset object.
-
-    Returns
-    -------
-    bounding boxes dataset containing the boxes tracks,
-    boxes shapes, confidence scores and associated metadata.
-
-    """
-    # Create the time coordinate
-    time_coords = data.frame_array.squeeze()  # type: ignore
-    time_unit = "frames"
-
-    dataset_attrs: dict[str, str | float | None] = {
-        "source_software": data.source_software,
-        "ds_type": "bboxes",
-    }
-    # if fps is provided:
-    # time_coords is expressed in seconds, with the time origin
-    # set as frame 0 == time 0 seconds
-    # Store fps as a dataset attribute
-    if data.fps:
-        # Compute elapsed time from frame 0.
-        # Ignoring type error because `data.frame_array` is not None after
-        # ValidBboxesDataset.__attrs_post_init__()  # type: ignore
-        time_coords = np.array(
-            [frame / data.fps for frame in data.frame_array.squeeze()]  # type: ignore
-        )
-        time_unit = "seconds"
-        dataset_attrs["fps"] = data.fps
-
-    dataset_attrs["time_unit"] = time_unit
-    # Convert data to an xarray.Dataset
-    # with dimensions ('time', 'space', 'individuals')
-    DIM_NAMES = ValidBboxesDataset.DIM_NAMES
-    n_space = data.position_array.shape[1]
-    return xr.Dataset(
-        data_vars={
-            "position": xr.DataArray(data.position_array, dims=DIM_NAMES),
-            "shape": xr.DataArray(data.shape_array, dims=DIM_NAMES),
-            "confidence": xr.DataArray(
-                data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:]
-            ),
-        },
-        coords={
-            DIM_NAMES[0]: time_coords,
-            DIM_NAMES[1]: ["x", "y", "z"][:n_space],
-            DIM_NAMES[2]: data.individual_names,
-        },
-        attrs=dataset_attrs,
-    )
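The removed helper above now lives on the validator class. A rough sketch of an equivalent `to_dataset` body, reconstructed from the deleted code with the simplified time_coords computation, as an assumption rather than the PR's exact implementation:

```python
import xarray as xr


def to_dataset(inputs) -> xr.Dataset:
    """Sketch: convert validated bbox inputs to a movement dataset.

    ``inputs`` stands in for a ValidBboxesInputs instance; attribute names
    follow the deleted helper above.
    """
    time_coords = inputs.frame_array.squeeze()
    time_unit = "frames"
    attrs: dict[str, str | float | None] = {
        "source_software": inputs.source_software,
        "ds_type": "bboxes",
    }
    if inputs.fps:
        # Express time in seconds, with frame 0 == time 0 s
        # (simplified from the per-frame list comprehension above)
        time_coords = time_coords / inputs.fps
        time_unit = "seconds"
        attrs["fps"] = inputs.fps
    attrs["time_unit"] = time_unit

    dims = type(inputs).DIM_NAMES  # ("time", "space", "individuals")
    n_space = inputs.position_array.shape[1]
    return xr.Dataset(
        data_vars={
            "position": xr.DataArray(inputs.position_array, dims=dims),
            "shape": xr.DataArray(inputs.shape_array, dims=dims),
            "confidence": xr.DataArray(
                inputs.confidence_array, dims=dims[:1] + dims[2:]
            ),
        },
        coords={
            dims[0]: time_coords,
            dims[1]: ["x", "y", "z"][:n_space],
            dims[2]: inputs.individual_names,
        },
        attrs=attrs,
    )
```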
74 changes: 15 additions & 59 deletions movement/io/load_poses.py
@@ -1,7 +1,7 @@
"""Load pose tracking data from various frameworks into ``movement``."""

from pathlib import Path
-from typing import TYPE_CHECKING, Literal
+from typing import Literal

import h5py
import numpy as np
@@ -12,7 +12,7 @@
from sleap_io.model.labels import Labels

from movement.utils.logging import logger
-from movement.validators.datasets import ValidPosesDataset
+from movement.validators.datasets import ValidPosesInputs
from movement.validators.files import (
ValidAniposeCSV,
ValidDeepLabCutCSV,
@@ -21,9 +21,6 @@
ValidNWBFile,
)

-if TYPE_CHECKING:
-    from numpy.typing import NDArray


def from_numpy(
position_array: np.ndarray,
@@ -85,15 +82,15 @@ def from_numpy(
... )

"""
-valid_data = ValidPosesDataset(
+valid_poses_inputs = ValidPosesInputs(
position_array=position_array,
confidence_array=confidence_array,
individual_names=individual_names,
keypoint_names=keypoint_names,
fps=fps,
source_software=source_software,
)
-return _ds_from_valid_data(valid_data)
+return valid_poses_inputs.to_dataset()


def from_file(
@@ -356,9 +353,19 @@ def from_lp_file(
>>> ds = load_poses.from_lp_file("path/to/file.csv", fps=30)

"""
-return _ds_from_lp_or_dlc_file(
+ds = _ds_from_lp_or_dlc_file(
file_path=file_path, source_software="LightningPose", fps=fps
)
+    n_individuals = ds.sizes.get("individuals", 1)
+    if n_individuals > 1:
+        raise logger.error(
+            ValueError(
+                "LightningPose only supports single-individual datasets, "
+                f"but the loaded dataset has {n_individuals} individuals. "
+                "Did you mean to load from a DeepLabCut file instead?"
+            )
+        )
+    return ds


def from_dlc_file(
@@ -695,57 +702,6 @@ def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame:
return df


-def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
-    """Create a ``movement`` poses dataset from validated pose tracking data.
-
-    Parameters
-    ----------
-    data : movement.io.tracks_validators.ValidPosesDataset
-        The validated data object.
-
-    Returns
-    -------
-    xarray.Dataset
-        ``movement`` dataset containing the pose tracks, confidence scores,
-        and associated metadata.
-
-    """
-    n_frames = data.position_array.shape[0]
-    n_space = data.position_array.shape[1]
-    dataset_attrs: dict[str, str | float | None] = {
-        "source_software": data.source_software,
-        "ds_type": "poses",
-    }
-    # Create the time coordinate, depending on the value of fps
-    time_coords: NDArray[np.floating] | NDArray[np.integer]
-    time_unit: Literal["seconds", "frames"]
-    if data.fps is not None:
-        time_coords = np.arange(n_frames, dtype=np.float64) / data.fps
-        time_unit = "seconds"
-        dataset_attrs["fps"] = data.fps
-    else:
-        time_coords = np.arange(n_frames, dtype=np.int64)
-        time_unit = "frames"
-    dataset_attrs["time_unit"] = time_unit
-    DIM_NAMES = ValidPosesDataset.DIM_NAMES
-    # Convert data to an xarray.Dataset
-    return xr.Dataset(
-        data_vars={
-            "position": xr.DataArray(data.position_array, dims=DIM_NAMES),
-            "confidence": xr.DataArray(
-                data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:]
-            ),
-        },
-        coords={
-            DIM_NAMES[0]: time_coords,
-            DIM_NAMES[1]: ["x", "y", "z"][:n_space],
-            DIM_NAMES[2]: data.keypoint_names,
-            DIM_NAMES[3]: data.individual_names,
-        },
-        attrs=dataset_attrs,
-    )


def from_anipose_style_df(
df: pd.DataFrame,
fps: float | None = None,
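Caller-side, the new guard in `from_lp_file` surfaces as a ValueError that points users at the DeepLabCut loader. A small usage sketch with an illustrative file path:

```python
from movement.io import load_poses

try:
    # LightningPose predictions are single-individual by definition
    ds = load_poses.from_lp_file("multi_animal_preds.csv", fps=30)
except ValueError:
    # Multi-individual file: fall back to the DeepLabCut loader,
    # as the error message suggests
    ds = load_poses.from_dlc_file("multi_animal_preds.csv", fps=30)
```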
4 changes: 2 additions & 2 deletions movement/io/save_bboxes.py
@@ -11,7 +11,7 @@
import xarray as xr

from movement.utils.logging import logger
-from movement.validators.datasets import ValidBboxesDataset, _validate_dataset
+from movement.validators.datasets import ValidBboxesInputs
from movement.validators.files import _validate_file_path

if TYPE_CHECKING:
@@ -128,7 +128,7 @@ def to_via_tracks_file(
"""
# Validate file path and dataset
file = _validate_file_path(file_path, expected_suffix=[".csv"])
-_validate_dataset(ds, ValidBboxesDataset)
+ValidBboxesInputs.validate(ds)

# Check the number of digits required to represent the frame numbers
frame_n_digits = _check_frame_required_digits(
8 changes: 4 additions & 4 deletions movement/io/save_poses.py
@@ -15,7 +15,7 @@
_write_processing_module,
)
from movement.utils.logging import logger
-from movement.validators.datasets import ValidPosesDataset, _validate_dataset
+from movement.validators.datasets import ValidPosesInputs
from movement.validators.files import _validate_file_path


@@ -126,7 +126,7 @@ def to_dlc_style_df(
to_dlc_file : Save dataset directly to a DeepLabCut-style .h5 or .csv file.

"""
-_validate_dataset(ds, ValidPosesDataset)
+ValidPosesInputs.validate(ds)
scorer = ["movement"]
bodyparts = ds.coords["keypoints"].data.tolist()
base_coords = ds.coords["space"].data.tolist()
@@ -265,7 +265,7 @@ def to_lp_file(

"""
file = _validate_file_path(file_path=file_path, expected_suffix=[".csv"])
-_validate_dataset(ds, ValidPosesDataset)
+ValidPosesInputs.validate(ds)
to_dlc_file(ds, file.path, split_individuals=True)


@@ -309,7 +309,7 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:

"""
file = _validate_file_path(file_path=file_path, expected_suffix=[".h5"])
-_validate_dataset(ds, ValidPosesDataset)
+ValidPosesInputs.validate(ds)

ds = _remove_unoccupied_tracks(ds)

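Across the save modules, the module-level `_validate_dataset(ds, SomeClass)` helper is replaced by a `validate` classmethod on each inputs class. A plausible shape for that classmethod, assuming hypothetical `DIM_NAMES`/`VAR_NAMES` class attributes (`DIM_NAMES` appears in the diff; `VAR_NAMES` is an assumption) and movement's `raise logger.error(...)` pattern, not the PR's exact code:

```python
import xarray as xr

from movement.utils.logging import logger


class _BaseDatasetInputs:
    """Sketch only; the real class lives in movement.validators.datasets."""

    DIM_NAMES: tuple[str, ...] = ()  # e.g. ("time", "space", "individuals")
    VAR_NAMES: tuple[str, ...] = ()  # hypothetical, e.g. ("position", "confidence")

    @classmethod
    def validate(cls, ds: xr.Dataset) -> None:
        """Raise if ``ds`` lacks the dims or data variables this type requires."""
        if not isinstance(ds, xr.Dataset):
            # movement's logger wrapper returns the exception, so it can be raised
            raise logger.error(
                TypeError(f"Expected an xarray Dataset, got {type(ds)}.")
            )
        missing_dims = set(cls.DIM_NAMES) - set(ds.dims)
        missing_vars = set(cls.VAR_NAMES) - set(ds.data_vars)
        if missing_dims or missing_vars:
            raise logger.error(
                ValueError(
                    f"Dataset is missing dims {sorted(missing_dims)} "
                    f"and/or variables {sorted(missing_vars)}."
                )
            )
```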
19 changes: 9 additions & 10 deletions movement/napari/loader_widgets.py
@@ -25,11 +25,7 @@
from movement.napari.convert import ds_to_napari_layers
from movement.napari.layer_styles import BoxesStyle, PointsStyle, TracksStyle
from movement.utils.logging import logger
-from movement.validators.datasets import (
-    ValidBboxesDataset,
-    ValidPosesDataset,
-    _validate_dataset,
-)
+from movement.validators.datasets import ValidBboxesInputs, ValidPosesInputs

# Allowed file suffixes for each supported source software
SUPPORTED_POSES_FILES = {
@@ -267,13 +263,16 @@ def _load_netcdf_file(self) -> xr.Dataset | None:
return None

# Validate dataset depending on its type
-        validator = {
-            "poses": ValidPosesDataset,
-            "bboxes": ValidBboxesDataset,
-        }[ds_type]
+        validators: dict[
+            str, type[ValidPosesInputs] | type[ValidBboxesInputs]
+        ] = {
+            "poses": ValidPosesInputs,
+            "bboxes": ValidBboxesInputs,
+        }
+        validator = validators[ds_type]

try:
-_validate_dataset(ds, validator)
+validator.validate(ds)
except (ValueError, TypeError) as e:
show_error(
f"The netCDF file does not appear to be a valid "
Expand Down
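Extracted as a standalone sketch, the annotated dispatch pattern from loader_widgets: the explicit dict annotation keeps the values typed as classes that share the `validate` classmethod, and the `except` clause matches the errors the validators raise (`handle_netcdf` is a hypothetical wrapper, not the widget method itself):

```python
import xarray as xr

from movement.validators.datasets import ValidBboxesInputs, ValidPosesInputs

# Map a dataset's ds_type attribute to the class that can validate it
VALIDATORS: dict[str, type[ValidPosesInputs] | type[ValidBboxesInputs]] = {
    "poses": ValidPosesInputs,
    "bboxes": ValidBboxesInputs,
}


def handle_netcdf(ds: xr.Dataset) -> xr.Dataset | None:
    """Validate a loaded netCDF dataset by its declared type, or reject it."""
    ds_type = ds.attrs.get("ds_type")
    if ds_type not in VALIDATORS:
        return None  # not a movement dataset
    try:
        VALIDATORS[ds_type].validate(ds)
    except (ValueError, TypeError):
        return None  # invalid for the declared type
    return ds
```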