Skip to content

Commit 92cde26

Browse files
committed
Extend napari dimension slider tests to bboxes data
1 parent 8d49d6a commit 92cde26

File tree

2 files changed

+212
-1
lines changed

2 files changed

+212
-1
lines changed

tests/fixtures/napari.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
import pytest
66

7-
from movement.io import save_poses
7+
from movement.io import save_bboxes, save_poses
88

99

1010
@pytest.fixture
@@ -129,6 +129,111 @@ def valid_poses_path_and_ds_nan_end(
129129
return (out_path, ds)
130130

131131

132+
@pytest.fixture
133+
def valid_bboxes_path_and_ds(valid_bboxes_dataset, tmp_path):
134+
"""Return a (path, dataset) pair representing a bboxes dataset
135+
with data for 10 frames.
136+
137+
The fixture is derived from the ``valid_bboxes_dataset`` fixture.
138+
"""
139+
out_path = tmp_path / "ds_bboxes.csv"
140+
save_bboxes.to_via_tracks_file(valid_bboxes_dataset, out_path)
141+
return (out_path, valid_bboxes_dataset)
142+
143+
144+
@pytest.fixture
145+
def valid_bboxes_path_and_ds_short(valid_bboxes_dataset, tmp_path):
146+
"""Return a (path, dataset) pair representing a bboxes dataset
147+
with data for 5 frames.
148+
149+
The fixture is derived from the ``valid_bboxes_dataset`` fixture.
150+
"""
151+
valid_bboxes_dataset = valid_bboxes_dataset.sel(time=slice(0, 5))
152+
out_path = tmp_path / "ds_bboxes_short.csv"
153+
save_bboxes.to_via_tracks_file(valid_bboxes_dataset, out_path)
154+
return (out_path, valid_bboxes_dataset)
155+
156+
157+
@pytest.fixture
158+
def valid_bboxes_path_and_ds_with_localised_nans(
159+
valid_bboxes_dataset, tmp_path
160+
):
161+
"""Return a factory of (path, dataset) pairs representing
162+
valid bboxes datasets with NaN values at specific locations.
163+
"""
164+
ds = valid_bboxes_dataset.copy(deep=True)
165+
166+
def _valid_bboxes_path_and_ds_with_localised_nans(
167+
nan_location, filename="ds_bboxes_with_nans.csv"
168+
):
169+
"""Return a valid bboxes dataset and corresponding file with NaN
170+
values at specific locations.
171+
172+
The ``nan_location`` parameter is a dictionary with keys
173+
``"time"`` and ``"individuals"`` specifying which coordinates
174+
to set to NaN.
175+
"""
176+
if nan_location["time"] == "start":
177+
time_point = 0
178+
elif nan_location["time"] == "middle":
179+
time_point = ds.coords["time"][
180+
ds.coords["time"].shape[0] // 2
181+
]
182+
elif nan_location["time"] == "end":
183+
time_point = ds.coords["time"][-1]
184+
185+
ds.position.loc[
186+
{
187+
"individuals": nan_location["individuals"],
188+
"time": time_point,
189+
}
190+
] = np.nan
191+
192+
out_path = tmp_path / filename
193+
save_bboxes.to_via_tracks_file(ds, out_path)
194+
return (out_path, ds)
195+
196+
return _valid_bboxes_path_and_ds_with_localised_nans
197+
198+
199+
@pytest.fixture
200+
def valid_bboxes_path_and_ds_nan_start(
201+
valid_bboxes_path_and_ds_with_localised_nans,
202+
):
203+
"""Return a (path, dataset) pair for a bboxes dataset with NaN
204+
values for one individual at the first frame.
205+
206+
Only one individual is set to NaN because VIA-tracks CSV format
207+
drops frames where all individuals have NaN positions.
208+
"""
209+
return valid_bboxes_path_and_ds_with_localised_nans(
210+
{
211+
"time": "start",
212+
"individuals": ["id_0"],
213+
},
214+
filename="ds_bboxes_with_nan_start.csv",
215+
)
216+
217+
218+
@pytest.fixture
219+
def valid_bboxes_path_and_ds_nan_end(
220+
valid_bboxes_path_and_ds_with_localised_nans,
221+
):
222+
"""Return a (path, dataset) pair for a bboxes dataset with NaN
223+
values for one individual at the last frame.
224+
225+
Only one individual is set to NaN because VIA-tracks CSV format
226+
drops frames where all individuals have NaN positions.
227+
"""
228+
return valid_bboxes_path_and_ds_with_localised_nans(
229+
{
230+
"time": "end",
231+
"individuals": ["id_0"],
232+
},
233+
filename="ds_bboxes_with_nan_end.csv",
234+
)
235+
236+
132237
@pytest.fixture
133238
def sample_layer_data(rng):
134239
"""Return a dictionary of sample data for each napari layer type."""

tests/test_unit/test_napari_plugin/test_data_loader_widget.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,65 @@ def test_dimension_slider_with_nans(
567567
)
568568

569569

570+
@pytest.mark.parametrize(
571+
"nan_time_location",
572+
["start", "middle", "end"],
573+
)
574+
def test_dimension_slider_with_nans_bboxes(
575+
valid_bboxes_path_and_ds_with_localised_nans,
576+
nan_time_location,
577+
make_napari_viewer_proxy,
578+
):
579+
"""Test that the dimension slider is set to the total number of frames
580+
when bboxes data layers with NaNs are loaded.
581+
582+
Only one individual is set to NaN at a given time, because the
583+
VIA-tracks CSV format drops frames where all individuals have NaN
584+
positions, making the round-trip lossy for all-NaN frames.
585+
"""
586+
nan_location = {
587+
"time": nan_time_location,
588+
"individuals": ["id_0"],
589+
}
590+
file_path, ds = valid_bboxes_path_and_ds_with_localised_nans(
591+
nan_location
592+
)
593+
594+
# Define the expected frame index with the NaN value
595+
if nan_location["time"] == "start":
596+
expected_frame = ds.coords["time"][0]
597+
elif nan_location["time"] == "middle":
598+
expected_frame = ds.coords["time"][ds.coords["time"].shape[0] // 2]
599+
elif nan_location["time"] == "end":
600+
expected_frame = ds.coords["time"][-1]
601+
602+
# Load the data loader widget
603+
viewer = make_napari_viewer_proxy()
604+
data_loader_widget = DataLoader(viewer)
605+
606+
# Read sample data with a NaN at the specified location
607+
data_loader_widget.file_path_edit.setText(file_path.as_posix())
608+
data_loader_widget.source_software_combo.setCurrentText("VIA-tracks")
609+
610+
# Check the data contains nans where expected
611+
assert (
612+
ds.position.sel(
613+
individuals=nan_location["individuals"],
614+
time=expected_frame,
615+
)
616+
.isnull()
617+
.all()
618+
)
619+
620+
# Call the _on_load_clicked method
621+
data_loader_widget._on_load_clicked()
622+
623+
# Check the frame slider is set to the full range of frames
624+
assert viewer.dims.range[0] == RangeTuple(
625+
start=0.0, stop=ds.position.shape[0] - 1, step=1.0
626+
)
627+
628+
570629
@pytest.mark.parametrize(
571630
"list_input_data_files",
572631
[
@@ -614,6 +673,53 @@ def test_dimension_slider_multiple_files(
614673
assert max_frames == ds_long.sizes["time"]
615674

616675

676+
@pytest.mark.parametrize(
677+
"list_input_data_files",
678+
[
679+
["valid_bboxes_path_and_ds", "valid_bboxes_path_and_ds_short"],
680+
["valid_bboxes_path_and_ds_short", "valid_bboxes_path_and_ds"],
681+
],
682+
ids=["long_first", "short_first"],
683+
)
684+
def test_dimension_slider_multiple_files_bboxes(
685+
list_input_data_files, make_napari_viewer_proxy, request
686+
):
687+
"""Test that the dimension slider is set to the maximum number of frames
688+
when multiple bboxes data layers are loaded.
689+
"""
690+
# Get the datasets to load (paths and ds)
691+
list_paths, list_datasets = [
692+
[
693+
request.getfixturevalue(file_name)[j]
694+
for file_name in list_input_data_files
695+
]
696+
for j in range(len(list_input_data_files))
697+
]
698+
699+
# Get the maximum number of frames from all datasets
700+
max_frames = max(ds.sizes["time"] for ds in list_datasets)
701+
702+
# Load the data loader widget
703+
viewer = make_napari_viewer_proxy()
704+
data_loader_widget = DataLoader(viewer)
705+
706+
# Load each dataset in order
707+
for file_path in list_paths:
708+
data_loader_widget.file_path_edit.setText(file_path.as_posix())
709+
data_loader_widget.source_software_combo.setCurrentText("VIA-tracks")
710+
data_loader_widget._on_load_clicked()
711+
712+
# Check the frame slider is as expected
713+
assert viewer.dims.range[0] == RangeTuple(
714+
start=0.0, stop=max_frames - 1, step=1.0
715+
)
716+
717+
# Check the maximum number of frames is the number of frames
718+
# in the longest dataset
719+
_, ds_long = request.getfixturevalue("valid_bboxes_path_and_ds")
720+
assert max_frames == ds_long.sizes["time"]
721+
722+
617723
@pytest.mark.parametrize(
618724
"list_input_data_files",
619725
[

0 commit comments

Comments
 (0)