Skip to content

Commit d190ce5

Browse files
vigjipre-commit-ci[bot]niksirbi
authored
Loading function for Anipose data (#358)
* first draft of loading function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adapted to new dimensions order * adapted to work with new dims arrangement * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * anipose loader test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * validator for anipose file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * anipose validator finished * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests/test_unit/test_validators/test_files_validators.py Co-authored-by: Niko Sirmpilatze <[email protected]> * simplified validator test * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze <[email protected]> * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze <[email protected]> * implementing fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more consistency fixes * moved anipose loading test to load_poses * fixed validators tests * tests for anipose loading done properly * docstring fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Implementing direct anipose load from from_file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ruffed * trying to fix mypy 
check * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * final touches to docstrings * added entry in input_output docs * define anipose link in conf.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Niko Sirmpilatze <[email protected]>
1 parent e7d8e47 commit d190ce5

File tree

7 files changed

+350
-4
lines changed

7 files changed

+350
-4
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@
203203
"xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}",
204204
"lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}",
205205
"via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}",
206+
"anipose": "https://anipose.readthedocs.io/en/latest/",
206207
}
207208

208209
intersphinx_mapping = {

docs/source/user_guide/input_output.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ To analyse pose tracks, `movement` supports loading data from various frameworks
1010
- [DeepLabCut](dlc:) (DLC)
1111
- [SLEAP](sleap:) (SLEAP)
1212
- [LightningPose](lp:) (LP)
13+
- [Anipose](anipose:) (Anipose)
1314

1415
To analyse bounding boxes' tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html).
1516

@@ -84,6 +85,22 @@ ds = load_poses.from_file(
8485
```
8586
:::
8687

88+
:::{tab-item} Anipose
89+
90+
To load Anipose files in .csv format:
91+
```python
92+
ds = load_poses.from_anipose_file(
93+
"/path/to/file.analysis.csv", fps=30, individual_name="individual_0"
94+
) # We can optionally specify the individual name, by default it is "individual_0"
95+
96+
# or equivalently
97+
ds = load_poses.from_file(
98+
"/path/to/file.analysis.csv", source_software="Anipose", fps=30, individual_name="individual_0"
99+
)
100+
101+
```
102+
:::
103+
87104
:::{tab-item} From NumPy
88105

89106
In the example below, we create random position data for two individuals, ``Alice`` and ``Bob``,

movement/io/load_poses.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313

1414
from movement.utils.logging import log_error, log_warning
1515
from movement.validators.datasets import ValidPosesDataset
16-
from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5
16+
from movement.validators.files import (
17+
ValidAniposeCSV,
18+
ValidDeepLabCutCSV,
19+
ValidFile,
20+
ValidHDF5,
21+
)
1722

1823
logger = logging.getLogger(__name__)
1924

@@ -91,8 +96,11 @@ def from_numpy(
9196

9297
def from_file(
9398
file_path: Path | str,
94-
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
99+
source_software: Literal[
100+
"DeepLabCut", "SLEAP", "LightningPose", "Anipose"
101+
],
95102
fps: float | None = None,
103+
**kwargs,
96104
) -> xr.Dataset:
97105
"""Create a ``movement`` poses dataset from any supported file.
98106
@@ -104,11 +112,14 @@ def from_file(
104112
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
105113
functions will be called internally, based on
106114
the value of ``source_software``.
107-
source_software : "DeepLabCut", "SLEAP" or "LightningPose"
115+
source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose"
108116
The source software of the file.
109117
fps : float, optional
110118
The number of frames per second in the video. If None (default),
111119
the ``time`` coordinates will be in frame numbers.
120+
**kwargs : dict, optional
121+
Additional keyword arguments to pass to the software-specific
122+
loading functions that are listed under "See Also".
112123
113124
Returns
114125
-------
@@ -121,6 +132,7 @@ def from_file(
121132
movement.io.load_poses.from_dlc_file
122133
movement.io.load_poses.from_sleap_file
123134
movement.io.load_poses.from_lp_file
135+
movement.io.load_poses.from_anipose_file
124136
125137
Examples
126138
--------
@@ -136,6 +148,8 @@ def from_file(
136148
return from_sleap_file(file_path, fps)
137149
elif source_software == "LightningPose":
138150
return from_lp_file(file_path, fps)
151+
elif source_software == "Anipose":
152+
return from_anipose_file(file_path, fps, **kwargs)
139153
else:
140154
raise log_error(
141155
ValueError, f"Unsupported source software: {source_software}"
@@ -696,3 +710,116 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
696710
"ds_type": "poses",
697711
},
698712
)
713+
714+
715+
def from_anipose_style_df(
    df: pd.DataFrame,
    fps: float | None = None,
    individual_name: str = "individual_0",
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an Anipose 3D dataframe.

    Parameters
    ----------
    df : pd.DataFrame
        Anipose triangulation dataframe
    fps : float, optional
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame units.
    individual_name : str, optional
        Name of the individual, by default "individual_0"

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Notes
    -----
    Reshape dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z,
    keypoint1_score, keypoint2_x, keypoint2_y, keypoint2_z,
    keypoint2_score... to an array of positions with dimensions
    time, space, keypoints, individuals, and an array of confidence
    (from scores) with dimensions time, keypoints, individuals.

    """
    # A keypoint is any column stem that occurs with an _x/_y/_z suffix;
    # a set comprehension deduplicates stems across the three axes.
    keypoint_names = sorted(
        {
            col.rsplit("_", 1)[0]
            for col in df.columns
            if any(col.endswith(f"_{s}") for s in ("x", "y", "z"))
        }
    )

    n_frames = len(df)
    n_keypoints = len(keypoint_names)

    # Initialize arrays and fill
    # (the trailing axis of size 1 holds the single individual)
    position_array = np.zeros((n_frames, 3, n_keypoints, 1))
    confidence_array = np.zeros((n_frames, n_keypoints, 1))
    for i, kp in enumerate(keypoint_names):
        for j, coord in enumerate(("x", "y", "z")):
            position_array[:, j, i, 0] = df[f"{kp}_{coord}"]
        confidence_array[:, i, 0] = df[f"{kp}_score"]

    individual_names = [individual_name]

    return from_numpy(
        position_array=position_array,
        confidence_array=confidence_array,
        individual_names=individual_names,
        keypoint_names=keypoint_names,
        source_software="Anipose",
        fps=fps,
    )
783+
784+
785+
def from_anipose_file(
    file_path: Path | str,
    fps: float | None = None,
    individual_name: str = "individual_0",
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an Anipose 3D .csv file.

    Parameters
    ----------
    file_path : pathlib.Path
        Path to the Anipose triangulation .csv file
    fps : float, optional
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame units.
    individual_name : str, optional
        Name of the individual, by default "individual_0"

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Notes
    -----
    We currently do not load all information, only x, y, z, and score
    (confidence) for each keypoint. Future versions will load n of cameras
    and error.

    """
    # Check path, permissions, and suffix before touching the contents.
    validated_file = ValidFile(
        file_path,
        expected_permission="r",
        expected_suffix=[".csv"],
    )
    # Ensure the header matches the Anipose triangulation layout.
    validated_csv = ValidAniposeCSV(validated_file.path)
    triangulation_df = pd.read_csv(validated_csv.path)
    return from_anipose_style_df(
        triangulation_df, fps=fps, individual_name=individual_name
    )

movement/validators/files.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,94 @@ def _file_contains_expected_levels(self, attribute, value):
221221
)
222222

223223

224+
@define
class ValidAniposeCSV:
    """Class for validating Anipose-style 3D pose .csv files.

    The validator ensures that the file contains the
    expected column names in its header (first row).

    Attributes
    ----------
    path : pathlib.Path
        Path to the .csv file.

    Raises
    ------
    ValueError
        If the .csv file does not contain the expected Anipose columns.

    """

    path: Path = field(validator=validators.instance_of(Path))

    @path.validator
    def _file_contains_expected_columns(self, attribute, value):
        """Ensure that the .csv file contains the expected columns."""
        # Every keypoint column is "<keypoint>" + one of these suffixes.
        expected_column_suffixes = [
            "_x",
            "_y",
            "_z",
            "_score",
            "_error",
            "_ncams",
        ]
        # Frame number, triangulation center, and rotation-matrix columns
        # that Anipose always writes alongside the keypoint columns.
        expected_non_keypoint_columns = [
            "fnum",
            "center_0",
            "center_1",
            "center_2",
            "M_00",
            "M_01",
            "M_02",
            "M_10",
            "M_11",
            "M_12",
            "M_20",
            "M_21",
            "M_22",
        ]

        # Read the first line of the CSV to get the headers
        with open(value) as f:
            columns = f.readline().strip().split(",")

        # Check that all expected headers are present
        if not all(col in columns for col in expected_non_keypoint_columns):
            raise log_error(
                ValueError,
                "CSV file is missing some expected columns. "
                f"Expected: {expected_non_keypoint_columns}.",
            )

        # For other headers, check they have expected suffixes and base names
        other_columns = [
            col for col in columns if col not in expected_non_keypoint_columns
        ]
        for column in other_columns:
            # Check suffix
            if not any(
                column.endswith(suffix) for suffix in expected_column_suffixes
            ):
                raise log_error(
                    ValueError,
                    f"Column {column} ends with an unexpected suffix.",
                )
            # Get base name by removing suffix
            base = column.rsplit("_", 1)[0]
            # Check base name has all expected suffixes
            if not all(
                f"{base}{suffix}" in columns
                for suffix in expected_column_suffixes
            ):
                raise log_error(
                    ValueError,
                    f"Keypoint {base} is missing some expected suffixes. "
                    f"Expected: {expected_column_suffixes}; "
                    f"Got: {columns}.",
                )
310+
311+
224312
@define
225313
class ValidVIATracksCSV:
226314
"""Class for validating VIA tracks .csv files.

tests/conftest.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,61 @@ def dlc_style_df():
199199
return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5"))
200200

201201

202+
@pytest.fixture
def missing_keypoint_columns_anipose_csv_file(tmp_path):
    """Return the file path for a fake single-individual .csv file."""
    file_path = tmp_path / "missing_keypoint_columns.csv"
    # Mandatory non-keypoint columns: frame number, center, rotation matrix.
    columns = ["fnum", "center_0", "center_1", "center_2"]
    columns += [f"M_{i}{j}" for i in range(3) for j in range(3)]
    # Here we are missing kp0_z:
    columns += ["kp0_x", "kp0_y", "kp0_score", "kp0_error", "kp0_ncams"]
    header = ",".join(columns)
    data_row = ",".join(["1"] * len(columns))
    file_path.write_text(header + "\n" + data_row)
    return file_path
228+
229+
230+
@pytest.fixture
def spurious_column_anipose_csv_file(tmp_path):
    """Return the file path for a fake single-individual .csv file."""
    file_path = tmp_path / "spurious_column.csv"
    # Mandatory non-keypoint columns: frame number, center, rotation matrix.
    columns = ["fnum", "center_0", "center_1", "center_2"]
    columns += [f"M_{i}{j}" for i in range(3) for j in range(3)]
    # A column whose name matches neither the fixed set nor a keypoint suffix.
    columns += ["funny_column"]
    header = ",".join(columns)
    data_row = ",".join(["1"] * len(columns))
    file_path.write_text(header + "\n" + data_row)
    return file_path
255+
256+
202257
@pytest.fixture(
203258
params=[
204259
"SLEAP_single-mouse_EPM.analysis.h5",

0 commit comments

Comments
 (0)