Skip to content

Commit 174817d

Browse files
vigjiniksirbi
andauthored
Add the option for multiple views loading - take2 (#346)
* moving to new branch * final test fix * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze <[email protected]> * doc change * fixed doc --------- Co-authored-by: Niko Sirmpilatze <[email protected]>
1 parent a3956c4 commit 174817d

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

docs/source/getting_started/movement_dataset.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ To discuss the specifics of both types of `movement` datasets, it is useful to c
1616
To learn more about `xarray` data structures in general, see the relevant
1717
[documentation](xarray:user-guide/data-structures.html).
1818

19-
2019
## Dataset structure
2120

2221
```{figure} ../_static/dataset_structure.png
@@ -135,6 +134,20 @@ In both cases, appropriate **coordinates** are assigned to each **dimension**.
135134
- `space` is labelled with either `x`, `y` (2D) or `x`, `y`, `z` (3D). Note that bounding boxes datasets are restricted to 2D space.
136135
- `time` is labelled in seconds if `fps` is provided, otherwise the **coordinates** are expressed in frames (ascending 0-indexed integers).
137136

137+
:::{dropdown} Additional dimensions
138+
:color: info
139+
:icon: info
140+
The above **dimensions** and **coordinates** are created
141+
by default when loading a `movement` dataset from a single
142+
file containing pose or bounding boxes tracks.
143+
144+
In some cases, you may encounter or create datasets with extra
145+
**dimensions**. For example, the
146+
{func}`movement.io.load_poses.from_multiview_files()` function
147+
creates an additional `views` **dimension**,
148+
with the **coordinates** being the names given to each camera view.
149+
:::
150+
138151
### Data variables
139152
The data variables in a `movement` dataset are the arrays that hold the actual data, as {class}`xarray.DataArray` objects.
140153

movement/io/load_poses.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,41 @@ def from_dlc_file(
351351
)
352352

353353

354+
def from_multiview_files(
355+
file_path_dict: dict[str, Path | str],
356+
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
357+
fps: float | None = None,
358+
) -> xr.Dataset:
359+
"""Load and merge pose tracking data from multiple views (cameras).
360+
361+
Parameters
362+
----------
363+
file_path_dict : dict[str, Union[Path, str]]
364+
A dict whose keys are the view names and values are the paths to load.
365+
source_software : {'LightningPose', 'SLEAP', 'DeepLabCut'}
366+
The source software of the file.
367+
fps : float, optional
368+
The number of frames per second in the video. If None (default),
369+
the `time` coordinates will be in frame numbers.
370+
371+
Returns
372+
-------
373+
xarray.Dataset
374+
``movement`` dataset containing the pose tracks, confidence scores,
375+
and associated metadata, with an additional ``views`` dimension.
376+
377+
"""
378+
views_list = list(file_path_dict.keys())
379+
new_coord_views = xr.DataArray(views_list, dims="view")
380+
381+
dataset_list = [
382+
from_file(f, source_software=source_software, fps=fps)
383+
for f in file_path_dict.values()
384+
]
385+
386+
return xr.concat(dataset_list, dim=new_coord_views)
387+
388+
354389
def _ds_from_lp_or_dlc_file(
355390
file_path: Path | str,
356391
source_software: Literal["LightningPose", "DeepLabCut"],

tests/test_unit/test_load_poses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,20 @@ def test_from_numpy_valid(
300300
source_software=source_software,
301301
)
302302
self.assert_dataset(ds, expected_source_software=source_software)
303+
304+
def test_from_multiview_files(self):
305+
"""Test that the from_file() function delegates to the correct
306+
loader function according to the source_software.
307+
"""
308+
view_names = ["view_0", "view_1"]
309+
file_path_dict = {
310+
view: DATA_PATHS.get("DLC_single-wasp.predictions.h5")
311+
for view in view_names
312+
}
313+
multi_view_ds = load_poses.from_multiview_files(
314+
file_path_dict, source_software="DeepLabCut"
315+
)
316+
317+
assert isinstance(multi_view_ds, xr.Dataset)
318+
assert "view" in multi_view_ds.dims
319+
assert multi_view_ds.view.values.tolist() == view_names

0 commit comments

Comments
 (0)