
Commit a3956c4

niksirbil and lochhh authored
Implement compute_speed and compute_path_length (#280)
* implement compute_speed and compute_path_length functions
* added speed to existing kinematics unit test
* rewrote compute_path_length with various nan policies
* unit test compute_path_length across time ranges
* fixed and refactored compute_path_length and its tests
* fixed docstring for compute_path_length
* Accept suggestion on docstring wording (Co-authored-by: Chang Huan Lo <[email protected]>)
* Remove print statement from test (Co-authored-by: Chang Huan Lo <[email protected]>)
* Ensure nan report is printed (Co-authored-by: Chang Huan Lo <[email protected]>)
* adapt warning message match in test
* change 'any' to 'all'
* uniform wording across path length docstrings
* (mostly) leave time range validation to xarray slice
* refactored parameters for test across time ranges
* simplified test for path length with nans
* replace drop policy with ffill
* remove B905 ruff rule
* make pre-commit happy

---------
Co-authored-by: Chang Huan Lo <[email protected]>
1 parent ca4daf2 commit a3956c4

File tree: 3 files changed (+416, -9 lines)


movement/kinematics.py

Lines changed: 187 additions & 1 deletion
@@ -7,7 +7,8 @@
 import xarray as xr
 from scipy.spatial.distance import cdist
 
-from movement.utils.logging import log_error
+from movement.utils.logging import log_error, log_warning
+from movement.utils.reports import report_nan_values
 from movement.utils.vector import compute_norm
 from movement.validators.arrays import validate_dims_coords
 
@@ -173,6 +174,30 @@ def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
     return result
 
 
+def compute_speed(data: xr.DataArray) -> xr.DataArray:
+    """Compute instantaneous speed at each time point.
+
+    Speed is a scalar quantity computed as the Euclidean norm (magnitude)
+    of the velocity vector at each time point.
+
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data containing position information, with ``time``
+        and ``space`` (in Cartesian coordinates) as required dimensions.
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray DataArray containing the computed speed,
+        with dimensions matching those of the input data,
+        except ``space`` is removed.
+
+    """
+    return compute_norm(compute_velocity(data))
+
+
 def compute_forward_vector(
     data: xr.DataArray,
     left_keypoint: str,
@@ -675,3 +700,164 @@ def _validate_type_data_array(data: xr.DataArray) -> None:
         TypeError,
         f"Input data must be an xarray.DataArray, but got {type(data)}.",
     )
+
+
+def compute_path_length(
+    data: xr.DataArray,
+    start: float | None = None,
+    stop: float | None = None,
+    nan_policy: Literal["ffill", "scale"] = "ffill",
+    nan_warn_threshold: float = 0.2,
+) -> xr.DataArray:
+    """Compute the length of a path travelled between two time points.
+
+    The path length is defined as the sum of the norms (magnitudes) of the
+    displacement vectors between two time points ``start`` and ``stop``,
+    which should be provided in the time units of the data array.
+    If not specified, the minimum and maximum time coordinates of the data
+    array are used as start and stop times, respectively.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data containing position information, with ``time``
+        and ``space`` (in Cartesian coordinates) as required dimensions.
+    start : float, optional
+        The start time of the path. If None (default),
+        the minimum time coordinate in the data is used.
+    stop : float, optional
+        The end time of the path. If None (default),
+        the maximum time coordinate in the data is used.
+    nan_policy : Literal["ffill", "scale"], optional
+        Policy to handle NaN (missing) values. Can be one of ``"ffill"``
+        or ``"scale"``. Defaults to ``"ffill"`` (forward fill).
+        See Notes for more details on the two policies.
+    nan_warn_threshold : float, optional
+        If more than this proportion of values are missing in any point track,
+        a warning will be emitted. Defaults to 0.2 (20%).
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray DataArray containing the computed path length,
+        with dimensions matching those of the input data,
+        except ``time`` and ``space`` are removed.
+
+    Notes
+    -----
+    Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill`
+    to forward-fill missing segments (NaN values) across time.
+    This equates to assuming that a track remains stationary for
+    the duration of the missing segment and then instantaneously moves to
+    the next valid position, following a straight line. This approach tends
+    to underestimate the path length, and the error increases with the number
+    of missing values.
+
+    Choosing ``nan_policy="scale"`` will adjust the path length based on
+    the proportion of valid segments per point track. For example, if only
+    80% of segments are present, the path length will be computed based on
+    these and the result will be divided by 0.8. This approach assumes
+    that motion dynamics are similar across observed and missing time
+    segments, which may not accurately reflect actual conditions.
+
+    """
+    validate_dims_coords(data, {"time": [], "space": []})
+    data = data.sel(time=slice(start, stop))
+    # Check that the data is not empty or too short
+    n_time = data.sizes["time"]
+    if n_time < 2:
+        raise log_error(
+            ValueError,
+            f"At least 2 time points are required to compute path length, "
+            f"but {n_time} were found. Double-check the start and stop times.",
+        )
+
+    _warn_about_nan_proportion(data, nan_warn_threshold)
+
+    if nan_policy == "ffill":
+        return compute_norm(
+            compute_displacement(data.ffill(dim="time")).isel(
+                time=slice(1, None)
+            )  # skip first displacement (always 0)
+        ).sum(dim="time", min_count=1)  # return NaN if no valid segment
+
+    elif nan_policy == "scale":
+        return _compute_scaled_path_length(data)
+    else:
+        raise log_error(
+            ValueError,
+            f"Invalid value for nan_policy: {nan_policy}. "
+            "Must be one of 'ffill' or 'scale'.",
+        )
+
+
+def _warn_about_nan_proportion(
+    data: xr.DataArray, nan_warn_threshold: float
+) -> None:
+    """Print a warning if the proportion of NaN values exceeds a threshold.
+
+    The NaN proportion is evaluated per point track, and a given point is
+    considered NaN if any of its ``space`` coordinates are NaN. The warning
+    specifically lists the point tracks that exceed the threshold.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data array.
+    nan_warn_threshold : float
+        The threshold for the proportion of NaN values. Must be a number
+        between 0 and 1.
+
+    """
+    nan_warn_threshold = float(nan_warn_threshold)
+    if not 0 <= nan_warn_threshold <= 1:
+        raise log_error(
+            ValueError,
+            "nan_warn_threshold must be between 0 and 1.",
+        )
+
+    n_nans = data.isnull().any(dim="space").sum(dim="time")
+    data_to_warn_about = data.where(
+        n_nans > data.sizes["time"] * nan_warn_threshold, drop=True
+    )
+    if len(data_to_warn_about) > 0:
+        log_warning(
+            "The result may be unreliable for point tracks with many "
+            "missing values. The following tracks have more than "
+            f"{nan_warn_threshold * 100:.3} % NaN values:",
+        )
+        print(report_nan_values(data_to_warn_about))
+
+
+def _compute_scaled_path_length(
+    data: xr.DataArray,
+) -> xr.DataArray:
+    """Compute scaled path length based on proportion of valid segments.
+
+    Path length is first computed based on valid segments (non-NaN values
+    on both ends of the segment) and then scaled based on the proportion of
+    valid segments per point track - i.e. the result is divided by the
+    proportion of valid segments.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data containing position information, with ``time``
+        and ``space`` (in Cartesian coordinates) as required dimensions.
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray DataArray containing the computed path length,
+        with dimensions matching those of the input data,
+        except ``time`` and ``space`` are removed.
+
+    """
+    # Skip first displacement segment (always 0) to not mess up the scaling
+    displacement = compute_displacement(data).isel(time=slice(1, None))
+    # count number of valid displacement segments per point track
+    valid_segments = (~displacement.isnull()).all(dim="space").sum(dim="time")
+    # compute proportion of valid segments per point track
+    valid_proportion = valid_segments / (data.sizes["time"] - 1)
+    # return scaled path length
+    return compute_norm(displacement).sum(dim="time") / valid_proportion
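
For orientation, here is a minimal usage sketch of the two new functions. It is not part of the commit; the synthetic position array, its coordinate values, and the expected results in the comments are illustrative assumptions.

    import numpy as np
    import xarray as xr

    from movement.kinematics import compute_path_length, compute_speed

    # Hypothetical position track: 10 time points in 2D Cartesian space,
    # moving 1 spatial unit along x per time unit (uniform linear motion).
    position = xr.DataArray(
        np.column_stack([np.arange(10.0), np.zeros(10)]),
        dims=["time", "space"],
        coords={"time": np.arange(10.0), "space": ["x", "y"]},
    )

    # Speed has the dims of the input minus "space"; ~1.0 at every time point here.
    speed = compute_speed(position)
    # Slicing time from 2 to 8 leaves 6 unit-length segments, so the result is 6.0.
    path_length = compute_path_length(position, start=2, stop=8)

With missing values, the default ``nan_policy="ffill"`` forward-fills gaps before summing displacements, whereas ``nan_policy="scale"`` sums over the valid segments and divides by the proportion of valid segments, as described in the Notes section of the docstring above.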

tests/conftest.py

Lines changed: 34 additions & 0 deletions
@@ -518,6 +518,40 @@ def valid_poses_dataset_uniform_linear_motion(
     )
 
 
+@pytest.fixture
+def valid_poses_dataset_uniform_linear_motion_with_nans(
+    valid_poses_dataset_uniform_linear_motion,
+):
+    """Return a valid poses dataset with NaN values in the position array.
+
+    Specifically, we will introduce:
+    - 1 NaN value in the centroid keypoint of individual id_1 at time=0
+    - 5 NaN values in the left keypoint of individual id_1 (frames 3-7)
+    - 10 NaN values in the right keypoint of individual id_1 (all frames)
+    """
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "centroid",
+            "time": 0,
+        }
+    ] = np.nan
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "left",
+            "time": slice(3, 7),
+        }
+    ] = np.nan
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "right",
+        }
+    ] = np.nan
+    return valid_poses_dataset_uniform_linear_motion
+
+
 # -------------------- Invalid datasets fixtures ------------------------------
 @pytest.fixture
 def not_a_dataset():
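
The commit's third changed file (the kinematics unit tests) is not shown on this page. As a rough, illustrative sketch only, a test built on this fixture might exercise the NaN handling of compute_path_length as below; the test name and parametrisation are assumptions, not the committed test code.

    import pytest

    from movement.kinematics import compute_path_length


    @pytest.mark.parametrize("nan_policy", ["ffill", "scale"])
    def test_path_length_with_all_nan_track(
        valid_poses_dataset_uniform_linear_motion_with_nans, nan_policy
    ):
        position = valid_poses_dataset_uniform_linear_motion_with_nans.position
        # The heavily-missing tracks of id_1 exceed the default 20% warning
        # threshold, so a NaN report is printed before the result is returned.
        path_length = compute_path_length(position, nan_policy=nan_policy)
        # The all-NaN "right" keypoint of id_1 has no valid segments, so both
        # policies yield NaN for that track.
        assert path_length.sel(individuals="id_1", keypoints="right").isnull()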
