Skip to content

Commit 0a33e5c

Browse files
committed
Rename dataset validator classes
1 parent 1b37373 commit 0a33e5c

File tree

10 files changed

+50
-48
lines changed

10 files changed

+50
-48
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ raise logger.exception(ValueError("message")) # with traceback
190190
We aim to adhere to the [When to use logging guide](inv:python#logging-basic-tutorial) to ensure consistency in our logging practices.
191191
In general:
192192
* Use {func}`print` for simple, non-critical messages that do not need to be logged.
193-
* Use {func}`warnings.warn` for user input issues that are non-critical and can be addressed within `movement`, e.g. deprecated function calls that are redirected, invalid `fps` number in {class}`ValidPosesDataset<movement.validators.datasets.ValidPosesDataset>` that is implicitly set to `None`; or when processing data containing excessive NaNs, which the user can potentially address using appropriate methods, e.g. {func}`interpolate_over_time()<movement.filtering.interpolate_over_time>`
194-
* Use {meth}`logger.warning()<loguru._logger.Logger.warning>` for non-critical issues where default values are assigned to optional parameters, e.g. `individual_names`, `keypoint_names` in {class}`ValidPosesDataset<movement.validators.datasets.ValidPosesDataset>`.
193+
* Use {func}`warnings.warn` for user input issues that are non-critical and can be addressed within `movement`, e.g. deprecated function calls that are redirected, invalid `fps` number in {class}`PosesValidator<movement.validators.datasets.PosesValidator>` that is implicitly set to `None`; or when processing data containing excessive NaNs, which the user can potentially address using appropriate methods, e.g. {func}`interpolate_over_time()<movement.filtering.interpolate_over_time>`
194+
* Use {meth}`logger.warning()<loguru._logger.Logger.warning>` for non-critical issues where default values are assigned to optional parameters, e.g. `individual_names`, `keypoint_names` in {class}`PosesValidator<movement.validators.datasets.PosesValidator>`.
195195
196196
### Continuous integration
197197
All pushes and pull requests will be built by [GitHub actions](github-docs:actions).

movement/io/load_bboxes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import xarray as xr
1212

1313
from movement.utils.logging import logger
14-
from movement.validators.datasets import ValidBboxesDataset
14+
from movement.validators.datasets import BboxesValidator
1515
from movement.validators.files import (
1616
DEFAULT_FRAME_REGEXP,
1717
ValidFile,
@@ -136,7 +136,7 @@ def from_numpy(
136136
... )
137137
138138
"""
139-
valid_bboxes_data = ValidBboxesDataset(
139+
valid_bboxes_data = BboxesValidator(
140140
position_array=position_array,
141141
shape_array=shape_array,
142142
confidence_array=confidence_array,
@@ -360,7 +360,7 @@ def from_via_tracks_file(
360360
),
361361
fps=fps,
362362
source_software="VIA-tracks",
363-
) # it validates the dataset via ValidBboxesDataset
363+
) # it validates the dataset via BboxesValidator
364364

365365
# Add metadata as attributes
366366
ds.attrs["source_software"] = "VIA-tracks"
@@ -650,13 +650,13 @@ def _via_attribute_column_to_numpy(
650650
return bbox_attr_array.squeeze()
651651

652652

653-
def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset:
653+
def _ds_from_valid_data(data: BboxesValidator) -> xr.Dataset:
654654
"""Convert a validated bounding boxes dataset to an xarray Dataset.
655655
656656
Parameters
657657
----------
658-
data : movement.validators.datasets.ValidBboxesDataset
659-
The validated bounding boxes dataset object.
658+
data : movement.validators.datasets.BboxesValidator
659+
The validator object containing the validated bounding boxes data.
660660
661661
Returns
662662
-------
@@ -678,8 +678,8 @@ def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset:
678678
# Store fps as a dataset attribute
679679
if data.fps:
680680
# Compute elapsed time from frame 0.
681-
# Ignoring type error because `data.frame_array` is not None after
682-
# ValidBboxesDataset.__attrs_post_init__() # type: ignore
681+
# Ignore type error as BboxesValidator ensures
682+
# `data.frame_array` is not None
683683
time_coords = np.array(
684684
[frame / data.fps for frame in data.frame_array.squeeze()] # type: ignore
685685
)
@@ -689,7 +689,7 @@ def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset:
689689
dataset_attrs["time_unit"] = time_unit
690690
# Convert data to an xarray.Dataset
691691
# with dimensions ('time', 'space', 'individuals')
692-
DIM_NAMES = ValidBboxesDataset.DIM_NAMES
692+
DIM_NAMES = BboxesValidator.DIM_NAMES
693693
n_space = data.position_array.shape[1]
694694
return xr.Dataset(
695695
data_vars={

movement/io/load_poses.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sleap_io.model.labels import Labels
1313

1414
from movement.utils.logging import logger
15-
from movement.validators.datasets import ValidPosesDataset
15+
from movement.validators.datasets import PosesValidator
1616
from movement.validators.files import (
1717
ValidAniposeCSV,
1818
ValidDeepLabCutCSV,
@@ -82,7 +82,7 @@ def from_numpy(
8282
... )
8383
8484
"""
85-
valid_data = ValidPosesDataset(
85+
valid_data = PosesValidator(
8686
position_array=position_array,
8787
confidence_array=confidence_array,
8888
individual_names=individual_names,
@@ -713,13 +713,13 @@ def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame:
713713
return df
714714

715715

716-
def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
716+
def _ds_from_valid_data(data: PosesValidator) -> xr.Dataset:
717717
"""Create a ``movement`` poses dataset from validated pose tracking data.
718718
719719
Parameters
720720
----------
721-
data : movement.io.tracks_validators.ValidPosesDataset
722-
The validated data object.
721+
data : movement.io.tracks_validators.PosesValidator
722+
The validator object containing the validated pose tracking data.
723723
724724
Returns
725725
-------
@@ -746,7 +746,7 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
746746

747747
dataset_attrs["time_unit"] = time_unit
748748

749-
DIM_NAMES = ValidPosesDataset.DIM_NAMES
749+
DIM_NAMES = PosesValidator.DIM_NAMES
750750
# Convert data to an xarray.Dataset
751751
return xr.Dataset(
752752
data_vars={

movement/io/save_bboxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import xarray as xr
1111

1212
from movement.utils.logging import logger
13-
from movement.validators.datasets import ValidBboxesDataset
13+
from movement.validators.datasets import BboxesValidator
1414
from movement.validators.files import _validate_file_path
1515

1616

@@ -124,7 +124,7 @@ def to_via_tracks_file(
124124
"""
125125
# Validate file path and dataset
126126
file = _validate_file_path(file_path, expected_suffix=[".csv"])
127-
ValidBboxesDataset.validate(ds)
127+
BboxesValidator.validate(ds)
128128

129129
# Check the number of digits required to represent the frame numbers
130130
frame_n_digits = _check_frame_required_digits(

movement/io/save_poses.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
_write_processing_module,
1616
)
1717
from movement.utils.logging import logger
18-
from movement.validators.datasets import ValidPosesDataset
18+
from movement.validators.datasets import PosesValidator
1919
from movement.validators.files import _validate_file_path
2020

2121

@@ -126,7 +126,7 @@ def to_dlc_style_df(
126126
to_dlc_file : Save dataset directly to a DeepLabCut-style .h5 or .csv file.
127127
128128
"""
129-
ValidPosesDataset.validate(ds)
129+
PosesValidator.validate(ds)
130130
scorer = ["movement"]
131131
bodyparts = ds.coords["keypoints"].data.tolist()
132132
base_coords = ds.coords["space"].data.tolist()
@@ -265,7 +265,7 @@ def to_lp_file(
265265
266266
"""
267267
file = _validate_file_path(file_path=file_path, expected_suffix=[".csv"])
268-
ValidPosesDataset.validate(ds)
268+
PosesValidator.validate(ds)
269269
to_dlc_file(ds, file.path, split_individuals=True)
270270

271271

@@ -309,7 +309,7 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
309309
310310
"""
311311
file = _validate_file_path(file_path=file_path, expected_suffix=[".h5"])
312-
ValidPosesDataset.validate(ds)
312+
PosesValidator.validate(ds)
313313

314314
ds = _remove_unoccupied_tracks(ds)
315315

movement/validators/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def validate(cls, ds: xr.Dataset) -> None:
248248

249249

250250
@define(kw_only=True)
251-
class ValidPosesDataset(_BaseValidDataset):
251+
class PosesValidator(_BaseValidDataset):
252252
"""Class for validating poses data intended for a ``movement`` dataset.
253253
254254
The validator ensures that within the ``movement poses`` dataset:
@@ -336,7 +336,7 @@ def __attrs_post_init__(self):
336336

337337

338338
@define(kw_only=True)
339-
class ValidBboxesDataset(_BaseValidDataset):
339+
class BboxesValidator(_BaseValidDataset):
340340
"""Class for validating bounding boxes data for a ``movement`` dataset.
341341
342342
The validator considers 2D bounding boxes only. It ensures that

tests/fixtures/datasets.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import pytest
66
import xarray as xr
77

8-
from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset
8+
from movement.validators.datasets import BboxesValidator, PosesValidator
99

1010

1111
# -------------------- Valid bboxes datasets and arrays --------------------
1212
@pytest.fixture
1313
def valid_bboxes_arrays_all_zeros():
1414
"""Return a dictionary of valid zero arrays (in terms of shape) for a
15-
ValidBboxesDataset.
15+
valid bboxes dataset.
1616
"""
1717
# define the shape of the arrays
1818
n_frames, n_space, n_individuals = (10, 2, 2)
@@ -31,7 +31,7 @@ def valid_bboxes_arrays_all_zeros():
3131
@pytest.fixture
3232
def valid_bboxes_arrays():
3333
"""Return a dictionary of valid arrays for a
34-
ValidBboxesDataset representing a uniform linear motion.
34+
BboxesValidator representing a uniform linear motion.
3535
3636
It represents 2 individuals for 10 frames, in 2D space.
3737
- Individual 0 moves along the x=y line from the origin.
@@ -94,7 +94,7 @@ def valid_bboxes_dataset(valid_bboxes_arrays):
9494
- Individual 0 at frames 2, 3, 4
9595
- Individual 1 at frames 2, 3
9696
"""
97-
dim_names = ValidBboxesDataset.DIM_NAMES
97+
dim_names = BboxesValidator.DIM_NAMES
9898

9999
position_array = valid_bboxes_arrays["position"]
100100
shape_array = valid_bboxes_arrays["shape"]
@@ -152,7 +152,7 @@ def valid_bboxes_dataset_with_nan(valid_bboxes_dataset):
152152
@pytest.fixture
153153
def valid_poses_arrays():
154154
"""Return a dictionary of valid arrays for a
155-
ValidPosesDataset representing a uniform linear motion.
155+
valid poses dataset representing a uniform linear motion.
156156
157157
This fixture is a factory of fixtures.
158158
Depending on the ``array_type`` requested (``multi_individual_array``,
@@ -173,7 +173,7 @@ def valid_poses_arrays():
173173
"""
174174

175175
def _valid_poses_arrays(array_type):
176-
"""Return a dictionary of valid arrays for a ValidPosesDataset."""
176+
"""Return a dictionary of valid arrays for a PosesValidator."""
177177
# Unless specified, default is a ``multi_individual_array`` with
178178
# 10 frames, 3 keypoints, and 2 individuals in 2D space.
179179
n_frames, n_space, n_keypoints, n_individuals = (10, 2, 3, 2)
@@ -250,7 +250,7 @@ def valid_poses_dataset(valid_poses_arrays, request):
250250
Default is a ``multi_individual_array`` (2 individuals, 3 keypoints each).
251251
See the ``valid_poses_arrays`` fixture for details.
252252
"""
253-
dim_names = ValidPosesDataset.DIM_NAMES
253+
dim_names = PosesValidator.DIM_NAMES
254254
# create a multi_individual_array by default unless overridden via param
255255
try:
256256
array_type = request.param

tests/test_unit/test_io/test_load_bboxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
from movement.io import load_bboxes
13-
from movement.validators.datasets import ValidBboxesDataset
13+
from movement.validators.datasets import BboxesValidator
1414

1515

1616
@pytest.fixture()
@@ -269,7 +269,7 @@ def test_from_file(
269269

270270
expected_values_bboxes = {
271271
"vars_dims": {"position": 3, "shape": 3, "confidence": 2},
272-
"dim_names": ValidBboxesDataset.DIM_NAMES,
272+
"dim_names": BboxesValidator.DIM_NAMES,
273273
}
274274

275275

tests/test_unit/test_io/test_load_poses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from pytest import DATA_PATHS
99

1010
from movement.io import load_poses
11-
from movement.validators.datasets import ValidPosesDataset
11+
from movement.validators.datasets import PosesValidator
1212

1313
expected_values_poses = {
1414
"vars_dims": {"position": 4, "confidence": 3},
15-
"dim_names": ValidPosesDataset.DIM_NAMES,
15+
"dim_names": PosesValidator.DIM_NAMES,
1616
}
1717

1818

tests/test_unit/test_validators/test_datasets_validators.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import xarray as xr
77

88
from movement.validators.datasets import (
9-
ValidBboxesDataset,
10-
ValidPosesDataset,
9+
BboxesValidator,
10+
PosesValidator,
1111
_BaseValidDataset,
1212
_convert_fps_to_none_if_invalid,
1313
_convert_to_list_of_str,
@@ -279,8 +279,8 @@ def test_validate(self, dataset_fixture, expected_exception, request):
279279
)
280280

281281

282-
class TestValidPosesDataset:
283-
"""Test the ValidPosesDataset class."""
282+
class TestPosesValidator:
283+
"""Test the PosesValidator class."""
284284

285285
@pytest.mark.parametrize(
286286
"position_array, keypoint_names, expected_context",
@@ -312,17 +312,17 @@ class TestValidPosesDataset:
312312
def test_keypoint_names(
313313
self, position_array, keypoint_names, expected_context
314314
):
315-
"""Test keypoint_names validation in ValidPosesDataset."""
315+
"""Test keypoint_names validation in PosesValidator."""
316316
with expected_context as expected_keypoint_names:
317-
ds = ValidPosesDataset(
317+
data = PosesValidator(
318318
position_array=position_array,
319319
keypoint_names=keypoint_names,
320320
)
321-
assert ds.keypoint_names == expected_keypoint_names
321+
assert data.keypoint_names == expected_keypoint_names
322322

323323

324-
class TestValidBboxesDataset:
325-
"""Test the ValidBboxesDataset class."""
324+
class TestBboxesValidator:
325+
"""Test the BboxesValidator class."""
326326

327327
@pytest.mark.parametrize(
328328
"shape_array, expected_context",
@@ -364,7 +364,7 @@ def test_shape_array(self, shape_array, expected_context):
364364
"""Test shape_array validation."""
365365
position_array = np.zeros((5, 2, 3)) # time, space, individuals
366366
with expected_context:
367-
ValidBboxesDataset(
367+
BboxesValidator(
368368
position_array=position_array,
369369
shape_array=shape_array,
370370
)
@@ -417,9 +417,11 @@ def test_frame_array(self, frame_array, expected_context):
417417
position_array = np.zeros((5, 2, 3)) # time, space, individuals
418418
shape_array = np.zeros((5, 2, 3))
419419
with expected_context as expected_frame_array:
420-
ds = ValidBboxesDataset(
420+
data = BboxesValidator(
421421
position_array=position_array,
422422
shape_array=shape_array,
423423
frame_array=frame_array,
424424
)
425-
np.testing.assert_array_equal(ds.frame_array, expected_frame_array)
425+
np.testing.assert_array_equal(
426+
data.frame_array, expected_frame_array
427+
)

0 commit comments

Comments
 (0)