Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@
"pynwb": ("https://pynwb.readthedocs.io/en/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"shapely": ("https://shapely.readthedocs.io/en/stable/", None),
}

# What to show on the 404 page
Expand Down
88 changes: 84 additions & 4 deletions movement/roi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -167,7 +169,7 @@ def _boundary_angle_computation(
Spatial position data, that is passed to
``how_to_compute_vector_to_region`` and used to compute the
"vector to the region".
reference_vector : xarray.DataArray | np.ndarray
reference_vector : xarray.DataArray | numpy.ndarray
Constant or time-varying vector to take signed angle with the
"vector to the region".
how_to_compute_vector_to_region : Callable
Expand Down Expand Up @@ -335,7 +337,7 @@ def compute_nearest_point_to(

Returns
-------
np.ndarray
numpy.ndarray
Coordinates of the point on ``self`` that is closest to
``position``.

Expand Down Expand Up @@ -380,7 +382,7 @@ def compute_approach_vector(

Returns
-------
np.ndarray
numpy.ndarray
Approach vector from the point to the region.

See Also
Expand Down Expand Up @@ -438,7 +440,7 @@ def compute_allocentric_angle_to_nearest_point(
in_degrees : bool
If ``True``, angles are returned in degrees. Otherwise angles are
returned in radians. Default ``False``.
reference_vector : np.ndarray or xarray.DataArray or None
reference_vector : ArrayLike | xarray.DataArray
The reference vector to be used. Dimensions must be compatible with
the argument of the same name that is passed to
:func:`compute_signed_angle_2d`. Default ``(1., 0.)``.
Expand Down Expand Up @@ -560,3 +562,81 @@ def plot(
if fig is None or ax is None:
fig, ax = plt.subplots(1, 1)
return self._plot(fig, ax, **matplotlib_kwargs)

def to_file(self, path: str | Path) -> None:
"""Save the region of interest to a file.

Parameters
----------
path : str | Path
Path to save the ROI file. The file will be saved in JSON format.

See Also
--------
from_file : Load a region of interest from a file.

Examples
--------
>>> from movement.roi import PolygonOfInterest
>>> roi = PolygonOfInterest([(0, 0), (1, 0), (1, 1)], name="triangle")
>>> roi.to_file("my_roi.json") # doctest: +SKIP

"""
data = {
"name": self._name,
"geometry_wkt": self.region.wkt,
"dimensions": self.dimensions,
"roi_type": self.__class__.__name__,
}
Path(path).write_text(json.dumps(data, indent=2))

@classmethod
def from_file(cls, path: str | Path) -> BaseRegionOfInterest:
"""Load a region of interest from a file.

Parameters
----------
path : str | Path
Path to the ROI file to load. Must be a JSON file saved by
:meth:`to_file`.

Returns
-------
BaseRegionOfInterest
The loaded region of interest object. The specific subclass
(LineOfInterest or PolygonOfInterest) is determined from the file.

Raises
------
FileNotFoundError
If the specified file does not exist.

See Also
--------
to_file : Save a region of interest to a file.

Examples
--------
>>> from movement.roi import PolygonOfInterest
>>> roi = PolygonOfInterest.from_file("my_roi.json") # doctest: +SKIP

"""
file_path = Path(path)
if not file_path.exists():
raise FileNotFoundError(f"ROI file not found: {path}")

data = json.loads(file_path.read_text())
geometry = shapely.from_wkt(data["geometry_wkt"])

# Import here to avoid circular imports
from movement.roi import LineOfInterest, PolygonOfInterest

roi_type = data.get("roi_type", "")
if roi_type == "LineOfInterest" or data["dimensions"] == 1:
return LineOfInterest._from_geometry(
geometry, name=data.get("name")
)
else:
return PolygonOfInterest._from_geometry(
geometry, name=data.get("name")
)
25 changes: 25 additions & 0 deletions movement/roi/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ def __init__(
line = shapely.normalize(line)
super().__init__(line, name=name)

@classmethod
def _from_geometry(
cls,
geometry: "shapely.LineString | shapely.LinearRing",
name: str | None = None,
) -> "LineOfInterest":
"""Construct a LineOfInterest from a shapely geometry.

Parameters
----------
geometry : shapely.LineString | shapely.LinearRing
The shapely geometry to construct from.
name : str, optional
Name for the LineOfInterest.

Returns
-------
LineOfInterest
A new LineOfInterest instance.

"""
points = geometry.coords
loop = isinstance(geometry, shapely.LinearRing)
return cls(points=points, loop=loop, name=name)

def _plot(
self, fig: Figure | SubFigure, ax: Axes, **matplotlib_kwargs
) -> tuple[Figure | SubFigure, Axes]:
Expand Down
29 changes: 29 additions & 0 deletions movement/roi/polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,35 @@ def __init__(
)
super().__init__(geometry=polygon, name=name)

@classmethod
def _from_geometry(
cls,
geometry: shapely.Polygon,
name: str | None = None,
) -> PolygonOfInterest:
"""Construct a PolygonOfInterest from a shapely geometry.

Parameters
----------
geometry : shapely.Polygon
The shapely geometry to construct from.
name : str, optional
Name for the PolygonOfInterest.

Returns
-------
PolygonOfInterest
A new PolygonOfInterest instance.

"""
exterior = geometry.exterior.coords
holes = (
[interior.coords for interior in geometry.interiors]
if geometry.interiors
else None
)
return cls(exterior_boundary=exterior, holes=holes, name=name)

@property
def _default_plot_args(self) -> dict[str, Any]:
return {
Expand Down
86 changes: 86 additions & 0 deletions tests/test_unit/test_roi/test_save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for saving and loading regions of interest to/from files."""

import json

import pytest

from movement.roi import LineOfInterest, PolygonOfInterest


class TestROISaveLoad:
"""Tests for ROI save/load functionality."""

def test_save_and_load_polygon_roi(self, tmp_path, triangle):
"""Test round-trip save and load for PolygonOfInterest."""
file_path = tmp_path / "triangle.json"

# Save
triangle.to_file(file_path)

# Verify file exists and has correct content
assert file_path.exists()
data = json.loads(file_path.read_text())
assert data["roi_type"] == "PolygonOfInterest"
assert data["dimensions"] == 2
assert data["name"] == "triangle"
assert "geometry_wkt" in data

# Load
loaded = PolygonOfInterest.from_file(file_path)

# Verify loaded ROI matches original
assert loaded.name == triangle.name
assert loaded.dimensions == triangle.dimensions
assert loaded.region.equals(triangle.region)

def test_save_and_load_line_roi(self, tmp_path, segment_of_y_equals_x):
"""Test round-trip save and load for LineOfInterest."""
file_path = tmp_path / "line.json"

# Save
segment_of_y_equals_x.to_file(file_path)

# Verify file exists
assert file_path.exists()
data = json.loads(file_path.read_text())
assert data["roi_type"] == "LineOfInterest"
assert data["dimensions"] == 1

# Load
loaded = LineOfInterest.from_file(file_path)

# Verify loaded ROI matches original
assert loaded.dimensions == segment_of_y_equals_x.dimensions
assert loaded.region.equals(segment_of_y_equals_x.region)

def test_save_and_load_polygon_with_hole(
self, tmp_path, unit_square_with_hole
):
"""Test round-trip for polygon with interior holes."""
file_path = tmp_path / "square_with_hole.json"

# Save
unit_square_with_hole.to_file(file_path)

# Load
loaded = PolygonOfInterest.from_file(file_path)

# Verify holes are preserved
assert loaded.region.equals(unit_square_with_hole.region)
assert len(loaded.holes) == len(unit_square_with_hole.holes)

def test_load_nonexistent_file_raises(self, tmp_path):
"""Test that loading a non-existent file raises FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="ROI file not found"):
PolygonOfInterest.from_file(tmp_path / "nonexistent.json")

def test_save_with_none_name(self, tmp_path, triangle_pts):
"""Test saving an ROI with no name."""
roi = PolygonOfInterest(triangle_pts) # No name provided
file_path = tmp_path / "unnamed.json"

roi.to_file(file_path)
loaded = PolygonOfInterest.from_file(file_path)

# Name should be None in both
assert loaded._name is None
Loading