Skip to content

Commit 536de0e

Browse files
lochhhniksirbipre-commit-ci[bot]
authored
Compute pairwise distances (#278)
* Draft inter-individual distances * Return vector norm in `compute_interindividual_distances` * Add `compute_interkeypoint_distances` * Refactor pairwise distances tests * Use `scipy.spatial.distance.cdist` * Add examples to docstrings * Rename variables * Update test function args + fix indentation * Handle scalar and 1d dims * Handle missing `core_dim` * Refactor `cdist` and tests * Fix docstrings * Reorder functions + cleanup docs * Reduce pairwise distances functions * Mention examples of available distance metrics * Update docstrings * Require `pairs` in `compute_pairwise_distances` * Raise error if there are no pairs to compute distances for * Rename `core_dim` to `labels_dim` * Spell out expected pairs in test * Merge old `kinematics` file changes into new * Rename `core_dim` to `labels_dim` in tests * Validate dims in `compute_pairwise_distances` * Apply suggestions from code review Co-authored-by: Niko Sirmpilatze <[email protected]> * Apply suggestions from code review Co-authored-by: Niko Sirmpilatze <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Niko Sirmpilatze <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b10896f commit 536de0e

File tree

3 files changed

+540
-3
lines changed

3 files changed

+540
-3
lines changed

movement/kinematics.py

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Compute kinematic variables like velocity and acceleration."""
22

3+
import itertools
34
from typing import Literal
45

56
import numpy as np
67
import xarray as xr
8+
from scipy.spatial.distance import cdist
79

810
from movement.utils.logging import log_error
911
from movement.utils.vector import compute_norm
@@ -324,6 +326,336 @@ def compute_head_direction_vector(
324326
)
325327

326328

329+
def _cdist(
330+
a: xr.DataArray,
331+
b: xr.DataArray,
332+
dim: Literal["individuals", "keypoints"],
333+
metric: str | None = "euclidean",
334+
**kwargs,
335+
) -> xr.DataArray:
336+
"""Compute distances between two position arrays across a given dimension.
337+
338+
This function is a wrapper around :func:`scipy.spatial.distance.cdist`
339+
and computes the pairwise distances between the two input position arrays
340+
across the dimension specified by ``dim``.
341+
The dimension can be either ``individuals`` or ``keypoints``.
342+
The distances are computed using the specified ``metric``.
343+
344+
Parameters
345+
----------
346+
a : xarray.DataArray
347+
The first input data containing position information of a
348+
single individual or keypoint, with ``time``, ``space``
349+
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
350+
(as specified by ``dim``) as required dimensions.
351+
b : xarray.DataArray
352+
The second input data containing position information of a
353+
single individual or keypoint, with ``time``, ``space``
354+
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
355+
(as specified by ``dim``) as required dimensions.
356+
dim : str
357+
The dimension to compute the distances for. Must be either
358+
``'individuals'`` or ``'keypoints'``.
359+
metric : str, optional
360+
The distance metric to use. Must be one of the options supported
361+
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
362+
``'euclidean'``, etc.
363+
Defaults to ``'euclidean'``.
364+
**kwargs : dict
365+
Additional keyword arguments to pass to
366+
:func:`scipy.spatial.distance.cdist`.
367+
368+
Returns
369+
-------
370+
xarray.DataArray
371+
An xarray DataArray containing the computed distances between
372+
each pair of inputs.
373+
374+
Examples
375+
--------
376+
Compute the Euclidean distance (default) between ``ind1`` and
377+
``ind2`` (i.e. interindividual distance for all keypoints)
378+
using the ``position`` data variable in the Dataset ``ds``:
379+
380+
>>> pos1 = ds.position.sel(individuals="ind1")
381+
>>> pos2 = ds.position.sel(individuals="ind2")
382+
>>> ind_dists = _cdist(pos1, pos2, dim="individuals")
383+
384+
Compute the Euclidean distance (default) between ``key1`` and
385+
``key2`` (i.e. interkeypoint distance for all individuals)
386+
using the ``position`` data variable in the Dataset ``ds``:
387+
388+
>>> pos1 = ds.position.sel(keypoints="key1")
389+
>>> pos2 = ds.position.sel(keypoints="key2")
390+
>>> key_dists = _cdist(pos1, pos2, dim="keypoints")
391+
392+
See Also
393+
--------
394+
scipy.spatial.distance.cdist : The underlying function used.
395+
compute_pairwise_distances : Compute pairwise distances between
396+
``individuals`` or ``keypoints``
397+
398+
"""
399+
# The dimension from which ``dim`` labels are obtained
400+
labels_dim = "individuals" if dim == "keypoints" else "keypoints"
401+
elem1 = getattr(a, dim).item()
402+
elem2 = getattr(b, dim).item()
403+
a = _validate_labels_dimension(a, labels_dim)
404+
b = _validate_labels_dimension(b, labels_dim)
405+
result = xr.apply_ufunc(
406+
cdist,
407+
a,
408+
b,
409+
kwargs={"metric": metric, **kwargs},
410+
input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]],
411+
output_core_dims=[[elem1, elem2]],
412+
vectorize=True,
413+
)
414+
result = result.assign_coords(
415+
{
416+
elem1: getattr(a, labels_dim).values,
417+
elem2: getattr(a, labels_dim).values,
418+
}
419+
)
420+
# Drop any squeezed coordinates
421+
return result.squeeze(drop=True)
422+
423+
424+
def compute_pairwise_distances(
425+
data: xr.DataArray,
426+
dim: Literal["individuals", "keypoints"],
427+
pairs: dict[str, str | list[str]] | Literal["all"],
428+
metric: str | None = "euclidean",
429+
**kwargs,
430+
) -> xr.DataArray | dict[str, xr.DataArray]:
431+
"""Compute pairwise distances between ``individuals`` or ``keypoints``.
432+
433+
This function computes the distances between
434+
pairs of ``individuals`` (i.e. interindividual distances) or
435+
pairs of ``keypoints`` (i.e. interkeypoint distances),
436+
as determined by ``dim``.
437+
The distances are computed for the given ``pairs``
438+
using the specified ``metric``.
439+
440+
Parameters
441+
----------
442+
data : xarray.DataArray
443+
The input data containing position information, with ``time``,
444+
``space`` (in Cartesian coordinates), and
445+
``individuals`` or ``keypoints`` (as specified by ``dim``)
446+
as required dimensions.
447+
dim : Literal["individuals", "keypoints"]
448+
The dimension to compute the distances for. Must be either
449+
``'individuals'`` or ``'keypoints'``.
450+
pairs : dict[str, str | list[str]] or 'all'
451+
Specifies the pairs of elements (either individuals or keypoints)
452+
for which to compute distances, depending on the value of ``dim``.
453+
454+
- If ``dim='individuals'``, ``pairs`` should be a dictionary where
455+
each key is an individual name, and each value is also an individual
456+
name or a list of such names to compute distances with.
457+
- If ``dim='keypoints'``, ``pairs`` should be a dictionary where each
458+
key is a keypoint name, and each value is also keypoint name or a
459+
list of such names to compute distances with.
460+
- Alternatively, use the special keyword ``'all'`` to compute distances
461+
for all possible pairs of individuals or keypoints
462+
(depending on ``dim``).
463+
metric : str, optional
464+
The distance metric to use. Must be one of the options supported
465+
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
466+
``'euclidean'``, etc.
467+
Defaults to ``'euclidean'``.
468+
**kwargs : dict
469+
Additional keyword arguments to pass to
470+
:func:`scipy.spatial.distance.cdist`.
471+
472+
Returns
473+
-------
474+
xarray.DataArray or dict[str, xarray.DataArray]
475+
The computed pairwise distances. If a single pair is specified in
476+
``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
477+
are specified, returns a dictionary where each key is a string
478+
representing the pair (e.g., ``'dist_ind1_ind2'`` or
479+
``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
480+
containing the computed distances for that pair.
481+
482+
Raises
483+
------
484+
ValueError
485+
If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
486+
if ``pairs`` is not a dictionary or ``'all'``; or
487+
if there are no pairs in ``data`` to compute distances for.
488+
489+
Examples
490+
--------
491+
Compute the Euclidean distance (default) between ``ind1`` and ``ind2``
492+
(i.e. interindividual distance), for all possible pairs of keypoints.
493+
494+
>>> position = xr.DataArray(
495+
... np.arange(36).reshape(2, 3, 3, 2),
496+
... coords={
497+
... "time": np.arange(2),
498+
... "individuals": ["ind1", "ind2", "ind3"],
499+
... "keypoints": ["key1", "key2", "key3"],
500+
... "space": ["x", "y"],
501+
... },
502+
... dims=["time", "individuals", "keypoints", "space"],
503+
... )
504+
>>> dist_ind1_ind2 = compute_pairwise_distances(
505+
... position, "individuals", {"ind1": "ind2"}
506+
... )
507+
>>> dist_ind1_ind2
508+
<xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
509+
8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
510+
Coordinates:
511+
* time (time) int64 16B 0 1
512+
* ind1 (ind1) <U4 48B 'key1' 'key2' 'key3'
513+
* ind2 (ind2) <U4 48B 'key1' 'key2' 'key3'
514+
515+
The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
516+
distances between ``ind1`` and ``ind2`` for all keypoints
517+
at each time point.
518+
519+
To obtain the distances between ``key1`` of ``ind1`` and
520+
``key2`` of ``ind2``:
521+
522+
>>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")
523+
524+
Compute the Euclidean distance (default) between ``key1`` and ``key2``
525+
(i.e. interkeypoint distance), for all possible pairs of individuals.
526+
527+
>>> dist_key1_key2 = compute_pairwise_distances(
528+
... position, "keypoints", {"key1": "key2"}
529+
... )
530+
>>> dist_key1_key2
531+
<xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
532+
2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
533+
Coordinates:
534+
* time (time) int64 16B 0 1
535+
* key1 (key1) <U4 48B 'ind1' 'ind2' 'ind3'
536+
* key2 (key2) <U4 48B 'ind1' 'ind2' 'ind3'
537+
538+
The resulting ``dist_key1_key2`` is a DataArray containing the computed
539+
distances between ``key1`` and ``key2`` for all individuals
540+
at each time point.
541+
542+
To obtain the distances between ``key1`` and ``key2`` within ``ind1``:
543+
544+
>>> dist_key1_key2.sel(key1="ind1", key2="ind1")
545+
546+
To obtain the distances between ``key1`` of ``ind1`` and
547+
``key2`` of ``ind2``:
548+
549+
>>> dist_key1_key2.sel(key1="ind1", key2="ind2")
550+
551+
Compute the city block or Manhattan distance for multiple pairs of
552+
keypoints using ``position``:
553+
554+
>>> key_dists = compute_pairwise_distances(
555+
... position,
556+
... "keypoints",
557+
... {"key1": "key2", "key3": ["key1", "key2"]},
558+
... metric="cityblock",
559+
... )
560+
>>> key_dists.keys()
561+
dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])
562+
563+
As multiple pairs of keypoints are specified,
564+
the resulting ``key_dists`` is a dictionary containing the DataArrays
565+
of computed distances for each pair of keypoints.
566+
567+
Compute the city block or Manhattan distance for all possible pairs of
568+
individuals using ``position``:
569+
570+
>>> ind_dists = compute_pairwise_distances(
571+
... position,
572+
... "individuals",
573+
... "all",
574+
... metric="cityblock",
575+
... )
576+
>>> ind_dists.keys()
577+
dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])
578+
579+
See Also
580+
--------
581+
scipy.spatial.distance.cdist : The underlying function used.
582+
583+
"""
584+
if dim not in ["individuals", "keypoints"]:
585+
raise log_error(
586+
ValueError,
587+
"'dim' must be either 'individuals' or 'keypoints', "
588+
f"but got {dim}.",
589+
)
590+
if isinstance(pairs, str) and pairs != "all":
591+
raise log_error(
592+
ValueError,
593+
f"'pairs' must be a dictionary or 'all', but got {pairs}.",
594+
)
595+
validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
596+
# Find all possible pair combinations if 'all' is specified
597+
if pairs == "all":
598+
paired_elements = list(
599+
itertools.combinations(getattr(data, dim).values, 2)
600+
)
601+
else:
602+
paired_elements = [
603+
(elem1, elem2)
604+
for elem1, elem2_list in pairs.items()
605+
for elem2 in
606+
(
607+
# Ensure elem2_list is a list
608+
[elem2_list] if isinstance(elem2_list, str) else elem2_list
609+
)
610+
]
611+
if not paired_elements:
612+
raise log_error(
613+
ValueError, "Could not find any pairs to compute distances for."
614+
)
615+
pairwise_distances = {
616+
f"dist_{elem1}_{elem2}": _cdist(
617+
data.sel({dim: elem1}),
618+
data.sel({dim: elem2}),
619+
dim=dim,
620+
metric=metric,
621+
**kwargs,
622+
)
623+
for elem1, elem2 in paired_elements
624+
}
625+
# Return DataArray if result only has one key
626+
if len(pairwise_distances) == 1:
627+
return next(iter(pairwise_distances.values()))
628+
return pairwise_distances
629+
630+
631+
def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
632+
"""Validate the input data contains the ``dim`` for labelling dimensions.
633+
634+
This function ensures the input data contains the ``dim``
635+
used as labels (coordinates) when applying
636+
:func:`scipy.spatial.distance.cdist` to
637+
the input data, by adding a temporary dimension if necessary.
638+
639+
Parameters
640+
----------
641+
data : xarray.DataArray
642+
The input data to validate.
643+
dim : str
644+
The dimension to validate.
645+
646+
Returns
647+
-------
648+
xarray.DataArray
649+
The input data with the labels dimension validated.
650+
651+
"""
652+
if data.coords.get(dim) is None:
653+
data = data.assign_coords({dim: "temp_dim"})
654+
if data.coords[dim].ndim == 0:
655+
data = data.expand_dims(dim).transpose("time", "space", dim)
656+
return data
657+
658+
327659
def _validate_type_data_array(data: xr.DataArray) -> None:
328660
"""Validate the input data is an xarray DataArray.
329661

tests/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ def missing_two_dims_bboxes_dataset(valid_bboxes_dataset):
567567
return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"})
568568

569569

570+
# --------------------------- Kinematics fixtures ---------------------------
570571
@pytest.fixture(params=["displacement", "velocity", "acceleration"])
571572
def kinematic_property(request):
572573
"""Return a kinematic property."""
@@ -820,6 +821,7 @@ def track_ids_not_unique_per_frame(
820821
return file_path
821822

822823

824+
# ----------------- Helpers fixture -----------------
823825
class Helpers:
824826
"""Generic helper methods for ``movement`` test modules."""
825827

@@ -834,9 +836,6 @@ def count_consecutive_nans(da):
834836
return (da.isnull().astype(int).diff("time") == 1).sum().item()
835837

836838

837-
# ----------------- Helper fixture -----------------
838-
839-
840839
@pytest.fixture
841840
def helpers():
842841
"""Return an instance of the ``Helpers`` class."""

0 commit comments

Comments
 (0)