|
1 | 1 | """Compute kinematic variables like velocity and acceleration.""" |
2 | 2 |
|
| 3 | +import itertools |
3 | 4 | from typing import Literal |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import xarray as xr |
| 8 | +from scipy.spatial.distance import cdist |
7 | 9 |
|
8 | 10 | from movement.utils.logging import log_error |
9 | 11 | from movement.utils.vector import compute_norm |
@@ -324,6 +326,336 @@ def compute_head_direction_vector( |
324 | 326 | ) |
325 | 327 |
|
326 | 328 |
|
def _cdist(
    a: xr.DataArray,
    b: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray:
    """Compute pairwise distances between two position arrays along ``dim``.

    Thin wrapper around :func:`scipy.spatial.distance.cdist`, dispatched
    through :func:`xarray.apply_ufunc` so that the distance matrix is
    computed independently for every entry of the remaining (non-core)
    dimensions. Each input holds exactly one element of ``dim``
    (one individual or one keypoint); the *other* dimension provides the
    labels of the rows and columns of the resulting distance matrix.

    Parameters
    ----------
    a : xarray.DataArray
        Position data of the first element, i.e. a single individual or
        keypoint selected along ``dim``. Must contain ``time`` and
        ``space`` (Cartesian coordinates), with ``dim`` present as a
        scalar coordinate.
    b : xarray.DataArray
        Position data of the second element, with the same requirements
        as ``a``.
    dim : str
        The dimension the two elements were selected from. Either
        ``'individuals'`` or ``'keypoints'``.
    metric : str, optional
        Any distance metric understood by
        :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'`` or
        ``'euclidean'``. Defaults to ``'euclidean'``.
    **kwargs : dict
        Extra keyword arguments forwarded to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray
        The distances between every pair of labels of the two inputs.
        The two new dimensions are named after the elements of ``a`` and
        ``b`` and are labelled with the coordinates of the other
        dimension. All size-1 dimensions are squeezed out and their
        coordinates dropped.

    Examples
    --------
    Euclidean (default) distance between ``ind1`` and ``ind2``, i.e. the
    interindividual distance for all keypoints, using the ``position``
    data variable of the Dataset ``ds``:

    >>> pos1 = ds.position.sel(individuals="ind1")
    >>> pos2 = ds.position.sel(individuals="ind2")
    >>> ind_dists = _cdist(pos1, pos2, dim="individuals")

    Euclidean (default) distance between ``key1`` and ``key2``, i.e. the
    interkeypoint distance for all individuals, using the ``position``
    data variable of the Dataset ``ds``:

    >>> pos1 = ds.position.sel(keypoints="key1")
    >>> pos2 = ds.position.sel(keypoints="key2")
    >>> key_dists = _cdist(pos1, pos2, dim="keypoints")

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.
    compute_pairwise_distances : Compute pairwise distances between
        ``individuals`` or ``keypoints``

    """
    # The dimension that is *not* ``dim`` supplies the row/column labels
    if dim == "keypoints":
        labels_dim = "individuals"
    else:
        labels_dim = "keypoints"
    # Names of the two selected elements; they become the output dimensions
    first_name = getattr(a, dim).item()
    second_name = getattr(b, dim).item()
    # Make sure both inputs actually carry ``labels_dim`` as a dimension
    a = _validate_labels_dimension(a, labels_dim)
    b = _validate_labels_dimension(b, labels_dim)
    core_dims = [labels_dim, "space"]
    distances = xr.apply_ufunc(
        cdist,
        a,
        b,
        kwargs={"metric": metric, **kwargs},
        input_core_dims=[core_dims, core_dims],
        output_core_dims=[[first_name, second_name]],
        vectorize=True,
    )
    # Both output dimensions are labelled with the labels of ``a``
    labels = getattr(a, labels_dim).values
    distances = distances.assign_coords(
        {first_name: labels, second_name: labels}
    )
    # Drop any squeezed coordinates
    return distances.squeeze(drop=True)
| 422 | + |
| 423 | + |
def compute_pairwise_distances(
    data: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    pairs: dict[str, str | list[str]] | Literal["all"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray | dict[str, xr.DataArray]:
    """Compute pairwise distances between ``individuals`` or ``keypoints``.

    This function computes the distances between
    pairs of ``individuals`` (i.e. interindividual distances) or
    pairs of ``keypoints`` (i.e. interkeypoint distances),
    as determined by ``dim``.
    The distances are computed for the given ``pairs``
    using the specified ``metric``.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``,
        ``space`` (in Cartesian coordinates), and
        ``individuals`` or ``keypoints`` (as specified by ``dim``)
        as required dimensions.
    dim : Literal["individuals", "keypoints"]
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    pairs : dict[str, str | list[str]] or 'all'
        Specifies the pairs of elements (either individuals or keypoints)
        for which to compute distances, depending on the value of ``dim``.

        - If ``dim='individuals'``, ``pairs`` should be a dictionary where
          each key is an individual name, and each value is also an
          individual name or a list of such names to compute distances with.
        - If ``dim='keypoints'``, ``pairs`` should be a dictionary where each
          key is a keypoint name, and each value is also a keypoint name or a
          list of such names to compute distances with.
        - Alternatively, use the special keyword ``'all'`` to compute
          distances for all possible pairs of individuals or keypoints
          (depending on ``dim``).
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
        ``'euclidean'``, etc.
        Defaults to ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray or dict[str, xarray.DataArray]
        The computed pairwise distances. If a single pair is specified in
        ``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
        are specified, returns a dictionary where each key is a string
        representing the pair (e.g., ``'dist_ind1_ind2'`` or
        ``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
        containing the computed distances for that pair.

    Raises
    ------
    ValueError
        If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
        if ``pairs`` is not a dictionary or ``'all'`` (this includes
        non-string inputs such as lists or ``None``); or
        if there are no pairs in ``data`` to compute distances for.

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and ``ind2``
    (i.e. interindividual distance), for all possible pairs of keypoints.

    >>> position = xr.DataArray(
    ...     np.arange(36).reshape(2, 3, 3, 2),
    ...     coords={
    ...         "time": np.arange(2),
    ...         "individuals": ["ind1", "ind2", "ind3"],
    ...         "keypoints": ["key1", "key2", "key3"],
    ...         "space": ["x", "y"],
    ...     },
    ...     dims=["time", "individuals", "keypoints", "space"],
    ... )
    >>> dist_ind1_ind2 = compute_pairwise_distances(
    ...     position, "individuals", {"ind1": "ind2"}
    ... )
    >>> dist_ind1_ind2
    <xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
    8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
    Coordinates:
      * time     (time) int64 16B 0 1
      * ind1     (ind1) <U4 48B 'key1' 'key2' 'key3'
      * ind2     (ind2) <U4 48B 'key1' 'key2' 'key3'

    The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
    distances between ``ind1`` and ``ind2`` for all keypoints
    at each time point.

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")

    Compute the Euclidean distance (default) between ``key1`` and ``key2``
    (i.e. interkeypoint distance), for all possible pairs of individuals.

    >>> dist_key1_key2 = compute_pairwise_distances(
    ...     position, "keypoints", {"key1": "key2"}
    ... )
    >>> dist_key1_key2
    <xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
    2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
    Coordinates:
      * time     (time) int64 16B 0 1
      * key1     (key1) <U4 48B 'ind1' 'ind2' 'ind3'
      * key2     (key2) <U4 48B 'ind1' 'ind2' 'ind3'

    The resulting ``dist_key1_key2`` is a DataArray containing the computed
    distances between ``key1`` and ``key2`` for all individuals
    at each time point.

    To obtain the distances between ``key1`` and ``key2`` within ``ind1``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind1")

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind2")

    Compute the city block or Manhattan distance for multiple pairs of
    keypoints using ``position``:

    >>> key_dists = compute_pairwise_distances(
    ...     position,
    ...     "keypoints",
    ...     {"key1": "key2", "key3": ["key1", "key2"]},
    ...     metric="cityblock",
    ... )
    >>> key_dists.keys()
    dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])

    As multiple pairs of keypoints are specified,
    the resulting ``key_dists`` is a dictionary containing the DataArrays
    of computed distances for each pair of keypoints.

    Compute the city block or Manhattan distance for all possible pairs of
    individuals using ``position``:

    >>> ind_dists = compute_pairwise_distances(
    ...     position,
    ...     "individuals",
    ...     "all",
    ...     metric="cityblock",
    ... )
    >>> ind_dists.keys()
    dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.

    """
    # Imported locally so this block does not depend on the module's
    # top-level import list.
    from collections.abc import Mapping

    if dim not in ["individuals", "keypoints"]:
        raise log_error(
            ValueError,
            "'dim' must be either 'individuals' or 'keypoints', "
            f"but got {dim}.",
        )
    # Guard the comparison with ``isinstance`` so that array-like inputs
    # are never compared element-wise against the string 'all'.
    use_all_pairs = isinstance(pairs, str) and pairs == "all"
    # Reject anything that is neither 'all' nor a mapping up front (e.g.
    # a list, a tuple or None); otherwise the failure would only surface
    # later as an unlogged AttributeError on ``pairs.items()``.
    if not use_all_pairs and not isinstance(pairs, Mapping):
        raise log_error(
            ValueError,
            f"'pairs' must be a dictionary or 'all', but got {pairs}.",
        )
    validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
    # Find all possible pair combinations if 'all' is specified
    if use_all_pairs:
        paired_elements = list(
            itertools.combinations(getattr(data, dim).values, 2)
        )
    else:
        paired_elements = [
            (elem1, elem2)
            for elem1, elem2_list in pairs.items()
            for elem2 in
            (
                # Ensure elem2_list is a list
                [elem2_list] if isinstance(elem2_list, str) else elem2_list
            )
        ]
    if not paired_elements:
        raise log_error(
            ValueError, "Could not find any pairs to compute distances for."
        )
    pairwise_distances = {
        f"dist_{elem1}_{elem2}": _cdist(
            data.sel({dim: elem1}),
            data.sel({dim: elem2}),
            dim=dim,
            metric=metric,
            **kwargs,
        )
        for elem1, elem2 in paired_elements
    }
    # Return DataArray if result only has one key
    if len(pairwise_distances) == 1:
        return next(iter(pairwise_distances.values()))
    return pairwise_distances
| 629 | + |
| 630 | + |
def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
    """Ensure ``data`` carries ``dim`` as a real (non-scalar) dimension.

    The ``dim`` passed here is the one whose coordinates are used as
    labels when :func:`scipy.spatial.distance.cdist` is applied to the
    data. If it is absent, or only present as a scalar coordinate, it is
    added as a temporary length-1 dimension.

    Parameters
    ----------
    data : xarray.DataArray
        The input data to validate.
    dim : str
        The name of the labels dimension that must be present.

    Returns
    -------
    xarray.DataArray
        ``data`` unchanged if ``dim`` is already a non-scalar coordinate;
        otherwise ``data`` with ``dim`` expanded into a length-1 dimension
        and the dimensions ordered as ``("time", "space", dim)``.

    """
    coord_is_missing = data.coords.get(dim) is None
    if coord_is_missing:
        # Attach a placeholder scalar label so there is something to expand
        data = data.assign_coords({dim: "temp_dim"})
    coord_is_scalar = data.coords[dim].ndim == 0
    if not coord_is_scalar:
        return data
    # Promote the scalar coordinate to a length-1 dimension
    expanded = data.expand_dims(dim)
    return expanded.transpose("time", "space", dim)
| 657 | + |
| 658 | + |
327 | 659 | def _validate_type_data_array(data: xr.DataArray) -> None: |
328 | 660 | """Validate the input data is an xarray DataArray. |
329 | 661 |
|
|
0 commit comments