|
7 | 7 | import xarray as xr |
8 | 8 | from scipy.spatial.distance import cdist |
9 | 9 |
|
10 | | -from movement.utils.logging import log_error |
| 10 | +from movement.utils.logging import log_error, log_warning |
| 11 | +from movement.utils.reports import report_nan_values |
11 | 12 | from movement.utils.vector import compute_norm |
12 | 13 | from movement.validators.arrays import validate_dims_coords |
13 | 14 |
|
@@ -173,6 +174,30 @@ def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray: |
173 | 174 | return result |
174 | 175 |
|
175 | 176 |
|
| 177 | +def compute_speed(data: xr.DataArray) -> xr.DataArray: |
| 178 | + """Compute instantaneous speed at each time point. |
| 179 | +
|
| 180 | + Speed is a scalar quantity computed as the Euclidean norm (magnitude) |
| 181 | + of the velocity vector at each time point. |
| 182 | +
|
| 183 | +
|
| 184 | + Parameters |
| 185 | + ---------- |
| 186 | + data : xarray.DataArray |
| 187 | + The input data containing position information, with ``time`` |
| 188 | + and ``space`` (in Cartesian coordinates) as required dimensions. |
| 189 | +
|
| 190 | + Returns |
| 191 | + ------- |
| 192 | + xarray.DataArray |
| 193 | + An xarray DataArray containing the computed speed, |
| 194 | + with dimensions matching those of the input data, |
| 195 | + except ``space`` is removed. |
| 196 | +
|
| 197 | + """ |
| 198 | + return compute_norm(compute_velocity(data)) |
| 199 | + |
| 200 | + |
176 | 201 | def compute_forward_vector( |
177 | 202 | data: xr.DataArray, |
178 | 203 | left_keypoint: str, |
@@ -675,3 +700,164 @@ def _validate_type_data_array(data: xr.DataArray) -> None: |
675 | 700 | TypeError, |
676 | 701 | f"Input data must be an xarray.DataArray, but got {type(data)}.", |
677 | 702 | ) |
| 703 | + |
| 704 | + |
| 705 | +def compute_path_length( |
| 706 | + data: xr.DataArray, |
| 707 | + start: float | None = None, |
| 708 | + stop: float | None = None, |
| 709 | + nan_policy: Literal["ffill", "scale"] = "ffill", |
| 710 | + nan_warn_threshold: float = 0.2, |
| 711 | +) -> xr.DataArray: |
| 712 | + """Compute the length of a path travelled between two time points. |
| 713 | +
|
| 714 | + The path length is defined as the sum of the norms (magnitudes) of the |
| 715 | + displacement vectors between two time points ``start`` and ``stop``, |
| 716 | + which should be provided in the time units of the data array. |
| 717 | + If not specified, the minimum and maximum time coordinates of the data |
| 718 | + array are used as start and stop times, respectively. |
| 719 | +
|
| 720 | + Parameters |
| 721 | + ---------- |
| 722 | + data : xarray.DataArray |
| 723 | + The input data containing position information, with ``time`` |
| 724 | + and ``space`` (in Cartesian coordinates) as required dimensions. |
| 725 | + start : float, optional |
| 726 | + The start time of the path. If None (default), |
| 727 | + the minimum time coordinate in the data is used. |
| 728 | + stop : float, optional |
| 729 | + The end time of the path. If None (default), |
| 730 | + the maximum time coordinate in the data is used. |
| 731 | + nan_policy : Literal["ffill", "scale"], optional |
| 732 | + Policy to handle NaN (missing) values. Can be one of the ``"ffill"`` |
| 733 | + or ``"scale"``. Defaults to ``"ffill"`` (forward fill). |
| 734 | + See Notes for more details on the two policies. |
| 735 | + nan_warn_threshold : float, optional |
| 736 | + If more than this proportion of values are missing in any point track, |
| 737 | + a warning will be emitted. Defaults to 0.2 (20%). |
| 738 | +
|
| 739 | + Returns |
| 740 | + ------- |
| 741 | + xarray.DataArray |
| 742 | + An xarray DataArray containing the computed path length, |
| 743 | + with dimensions matching those of the input data, |
| 744 | + except ``time`` and ``space`` are removed. |
| 745 | +
|
| 746 | + Notes |
| 747 | + ----- |
| 748 | + Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill` |
| 749 | + to forward-fill missing segments (NaN values) across time. |
| 750 | + This equates to assuming that a track remains stationary for |
| 751 | + the duration of the missing segment and then instantaneously moves to |
| 752 | + the next valid position, following a straight line. This approach tends |
| 753 | + to underestimate the path length, and the error increases with the number |
| 754 | + of missing values. |
| 755 | +
|
| 756 | + Choosing ``nan_policy="scale"`` will adjust the path length based on the |
| 757 | + the proportion of valid segments per point track. For example, if only |
| 758 | + 80% of segments are present, the path length will be computed based on |
| 759 | + these and the result will be divided by 0.8. This approach assumes |
| 760 | + that motion dynamics are similar across observed and missing time |
| 761 | + segments, which may not accurately reflect actual conditions. |
| 762 | +
|
| 763 | + """ |
| 764 | + validate_dims_coords(data, {"time": [], "space": []}) |
| 765 | + data = data.sel(time=slice(start, stop)) |
| 766 | + # Check that the data is not empty or too short |
| 767 | + n_time = data.sizes["time"] |
| 768 | + if n_time < 2: |
| 769 | + raise log_error( |
| 770 | + ValueError, |
| 771 | + f"At least 2 time points are required to compute path length, " |
| 772 | + f"but {n_time} were found. Double-check the start and stop times.", |
| 773 | + ) |
| 774 | + |
| 775 | + _warn_about_nan_proportion(data, nan_warn_threshold) |
| 776 | + |
| 777 | + if nan_policy == "ffill": |
| 778 | + return compute_norm( |
| 779 | + compute_displacement(data.ffill(dim="time")).isel( |
| 780 | + time=slice(1, None) |
| 781 | + ) # skip first displacement (always 0) |
| 782 | + ).sum(dim="time", min_count=1) # return NaN if no valid segment |
| 783 | + |
| 784 | + elif nan_policy == "scale": |
| 785 | + return _compute_scaled_path_length(data) |
| 786 | + else: |
| 787 | + raise log_error( |
| 788 | + ValueError, |
| 789 | + f"Invalid value for nan_policy: {nan_policy}. " |
| 790 | + "Must be one of 'ffill' or 'scale'.", |
| 791 | + ) |
| 792 | + |
| 793 | + |
| 794 | +def _warn_about_nan_proportion( |
| 795 | + data: xr.DataArray, nan_warn_threshold: float |
| 796 | +) -> None: |
| 797 | + """Print a warning if the proportion of NaN values exceeds a threshold. |
| 798 | +
|
| 799 | + The NaN proportion is evaluated per point track, and a given point is |
| 800 | + considered NaN if any of its ``space`` coordinates are NaN. The warning |
| 801 | + specifically lists the point tracks that exceed the threshold. |
| 802 | +
|
| 803 | + Parameters |
| 804 | + ---------- |
| 805 | + data : xarray.DataArray |
| 806 | + The input data array. |
| 807 | + nan_warn_threshold : float |
| 808 | + The threshold for the proportion of NaN values. Must be a number |
| 809 | + between 0 and 1. |
| 810 | +
|
| 811 | + """ |
| 812 | + nan_warn_threshold = float(nan_warn_threshold) |
| 813 | + if not 0 <= nan_warn_threshold <= 1: |
| 814 | + raise log_error( |
| 815 | + ValueError, |
| 816 | + "nan_warn_threshold must be between 0 and 1.", |
| 817 | + ) |
| 818 | + |
| 819 | + n_nans = data.isnull().any(dim="space").sum(dim="time") |
| 820 | + data_to_warn_about = data.where( |
| 821 | + n_nans > data.sizes["time"] * nan_warn_threshold, drop=True |
| 822 | + ) |
| 823 | + if len(data_to_warn_about) > 0: |
| 824 | + log_warning( |
| 825 | + "The result may be unreliable for point tracks with many " |
| 826 | + "missing values. The following tracks have more than " |
| 827 | + f"{nan_warn_threshold * 100:.3} % NaN values:", |
| 828 | + ) |
| 829 | + print(report_nan_values(data_to_warn_about)) |
| 830 | + |
| 831 | + |
| 832 | +def _compute_scaled_path_length( |
| 833 | + data: xr.DataArray, |
| 834 | +) -> xr.DataArray: |
| 835 | + """Compute scaled path length based on proportion of valid segments. |
| 836 | +
|
| 837 | + Path length is first computed based on valid segments (non-NaN values |
| 838 | + on both ends of the segment) and then scaled based on the proportion of |
| 839 | + valid segments per point track - i.e. the result is divided by the |
| 840 | + proportion of valid segments. |
| 841 | +
|
| 842 | + Parameters |
| 843 | + ---------- |
| 844 | + data : xarray.DataArray |
| 845 | + The input data containing position information, with ``time`` |
| 846 | + and ``space`` (in Cartesian coordinates) as required dimensions. |
| 847 | +
|
| 848 | + Returns |
| 849 | + ------- |
| 850 | + xarray.DataArray |
| 851 | + An xarray DataArray containing the computed path length, |
| 852 | + with dimensions matching those of the input data, |
| 853 | + except ``time`` and ``space`` are removed. |
| 854 | +
|
| 855 | + """ |
| 856 | + # Skip first displacement segment (always 0) to not mess up the scaling |
| 857 | + displacement = compute_displacement(data).isel(time=slice(1, None)) |
| 858 | + # count number of valid displacement segments per point track |
| 859 | + valid_segments = (~displacement.isnull()).all(dim="space").sum(dim="time") |
| 860 | + # compute proportion of valid segments per point track |
| 861 | + valid_proportion = valid_segments / (data.sizes["time"] - 1) |
| 862 | + # return scaled path length |
| 863 | + return compute_norm(displacement).sum(dim="time") / valid_proportion |
0 commit comments