diff --git a/docs/source/conf.py b/docs/source/conf.py index 3597e8433..5a4a5e792 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -253,6 +253,7 @@ "pynwb": ("https://pynwb.readthedocs.io/en/stable/", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), + "shapely": ("https://shapely.readthedocs.io/en/stable/", None), } # What to show on the 404 page diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 144003400..efa1c9070 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -8,9 +8,10 @@ import numpy as np import pandas as pd +import pooch import xarray as xr -from movement.utils.logging import logger +from movement.utils.logging import hide_pooch_hash_logs, logger from movement.validators.datasets import ValidBboxesInputs from movement.validators.files import ( DEFAULT_FRAME_REGEXP, @@ -217,8 +218,26 @@ def from_file( >>> source_software="VIA-tracks", >>> fps=30, >>> ) + >>> # Load from a URL + >>> ds = load_bboxes.from_file( + >>> "https://github.com/neuroinformatics-unit/movement/raw/main/tests/data/bboxes/VIA_multiple-crabs_5-frames_labels.csv", + >>> source_software="VIA-tracks", + >>> fps=30, + >>> ) """ + # Download file if it is a URL + if str(file_path).startswith(("http://", "https://")): + with hide_pooch_hash_logs(): + file_path = pooch.retrieve( + url=file_path, + known_hash=None, + path=Path( + "~", ".movement", "data", "public_datasets" + ).expanduser(), + progressbar=True, + ) + if source_software == "VIA-tracks": return from_via_tracks_file( file_path, diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 646cb6873..f736e4031 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -6,12 +6,13 @@ import h5py import numpy as np import pandas as pd +import pooch import pynwb import xarray as xr from sleap_io.io.slp import read_labels from sleap_io.model.labels import Labels -from movement.utils.logging import logger +from movement.utils.logging import hide_pooch_hash_logs, logger from movement.validators.datasets import ValidPosesInputs from movement.validators.files import ( ValidAniposeCSV, @@ -149,8 +150,26 @@ def from_file( >>> ds = load_poses.from_file( ... "path/to/file.h5", source_software="DeepLabCut", fps=30 ... ) + >>> # Load from a URL + >>> ds = load_poses.from_file( + ... "https://github.com/neuroinformatics-unit/movement/raw/main/tests/data/DLC/single-mouse_EPM.predictions.h5", + ... source_software="DeepLabCut", + ... fps=30 + ... ) """ + # Download file if it is a URL + if str(file_path).startswith(("http://", "https://")): + with hide_pooch_hash_logs(): + file_path = pooch.retrieve( + url=file_path, + known_hash=None, + path=Path( + "~", ".movement", "data", "public_datasets" + ).expanduser(), + progressbar=True, + ) + if source_software == "DeepLabCut": return from_dlc_file(file_path, fps) elif source_software == "SLEAP": diff --git a/movement/roi/base.py b/movement/roi/base.py index e6d273b5c..0312a43ee 100644 --- a/movement/roi/base.py +++ b/movement/roi/base.py @@ -2,8 +2,10 @@ from __future__ import annotations +import json from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast import matplotlib.pyplot as plt @@ -167,7 +169,7 @@ def _boundary_angle_computation( Spatial position data, that is passed to ``how_to_compute_vector_to_region`` and used to compute the "vector to the region". - reference_vector : xarray.DataArray | np.ndarray + reference_vector : xarray.DataArray | numpy.ndarray Constant or time-varying vector to take signed angle with the "vector to the region". how_to_compute_vector_to_region : Callable @@ -335,7 +337,7 @@ def compute_nearest_point_to( Returns ------- - np.ndarray + numpy.ndarray Coordinates of the point on ``self`` that is closest to ``position``. @@ -380,7 +382,7 @@ def compute_approach_vector( Returns ------- - np.ndarray + numpy.ndarray Approach vector from the point to the region. See Also @@ -438,7 +440,7 @@ def compute_allocentric_angle_to_nearest_point( in_degrees : bool If ``True``, angles are returned in degrees. Otherwise angles are returned in radians. Default ``False``. - reference_vector : np.ndarray or xarray.DataArray or None + reference_vector : ArrayLike | xarray.DataArray The reference vector to be used. Dimensions must be compatible with the argument of the same name that is passed to :func:`compute_signed_angle_2d`. Default ``(1., 0.)``. @@ -560,3 +562,81 @@ def plot( if fig is None or ax is None: fig, ax = plt.subplots(1, 1) return self._plot(fig, ax, **matplotlib_kwargs) + + def to_file(self, path: str | Path) -> None: + """Save the region of interest to a file. + + Parameters + ---------- + path : str | Path + Path to save the ROI file. The file will be saved in JSON format. + + See Also + -------- + from_file : Load a region of interest from a file. + + Examples + -------- + >>> from movement.roi import PolygonOfInterest + >>> roi = PolygonOfInterest([(0, 0), (1, 0), (1, 1)], name="triangle") + >>> roi.to_file("my_roi.json") # doctest: +SKIP + + """ + data = { + "name": self._name, + "geometry_wkt": self.region.wkt, + "dimensions": self.dimensions, + "roi_type": self.__class__.__name__, + } + Path(path).write_text(json.dumps(data, indent=2)) + + @classmethod + def from_file(cls, path: str | Path) -> BaseRegionOfInterest: + """Load a region of interest from a file. + + Parameters + ---------- + path : str | Path + Path to the ROI file to load. Must be a JSON file saved by + :meth:`to_file`. + + Returns + ------- + BaseRegionOfInterest + The loaded region of interest object. The specific subclass + (LineOfInterest or PolygonOfInterest) is determined from the file. + + Raises + ------ + FileNotFoundError + If the specified file does not exist. + + See Also + -------- + to_file : Save a region of interest to a file. + + Examples + -------- + >>> from movement.roi import PolygonOfInterest + >>> roi = PolygonOfInterest.from_file("my_roi.json") # doctest: +SKIP + + """ + file_path = Path(path) + if not file_path.exists(): + raise FileNotFoundError(f"ROI file not found: {path}") + + data = json.loads(file_path.read_text()) + geometry = shapely.from_wkt(data["geometry_wkt"]) + + # Import here to avoid circular imports + from movement.roi import LineOfInterest, PolygonOfInterest + + roi_type = data.get("roi_type", "") + if roi_type == "LineOfInterest" or data["dimensions"] == 1: + return LineOfInterest._from_geometry( + geometry, name=data.get("name") + ) + else: + return PolygonOfInterest._from_geometry( + geometry, name=data.get("name") + ) diff --git a/movement/roi/line.py b/movement/roi/line.py index 8d6ce1108..a14811814 100644 --- a/movement/roi/line.py +++ b/movement/roi/line.py @@ -84,6 +84,31 @@ def __init__( line = shapely.normalize(line) super().__init__(line, name=name) + @classmethod + def _from_geometry( + cls, + geometry: "shapely.LineString | shapely.LinearRing", + name: str | None = None, + ) -> "LineOfInterest": + """Construct a LineOfInterest from a shapely geometry. + + Parameters + ---------- + geometry : shapely.LineString | shapely.LinearRing + The shapely geometry to construct from. + name : str, optional + Name for the LineOfInterest. + + Returns + ------- + LineOfInterest + A new LineOfInterest instance. + + """ + points = geometry.coords + loop = isinstance(geometry, shapely.LinearRing) + return cls(points=points, loop=loop, name=name) + def _plot( self, fig: Figure | SubFigure, ax: Axes, **matplotlib_kwargs ) -> tuple[Figure | SubFigure, Axes]: diff --git a/movement/roi/polygon.py b/movement/roi/polygon.py index 911266324..4a190ea20 100644 --- a/movement/roi/polygon.py +++ b/movement/roi/polygon.py @@ -85,6 +85,35 @@ def __init__( ) super().__init__(geometry=polygon, name=name) + @classmethod + def _from_geometry( + cls, + geometry: shapely.Polygon, + name: str | None = None, + ) -> PolygonOfInterest: + """Construct a PolygonOfInterest from a shapely geometry. + + Parameters + ---------- + geometry : shapely.Polygon + The shapely geometry to construct from. + name : str, optional + Name for the PolygonOfInterest. + + Returns + ------- + PolygonOfInterest + A new PolygonOfInterest instance. + + """ + exterior = geometry.exterior.coords + holes = ( + [interior.coords for interior in geometry.interiors] + if geometry.interiors + else None + ) + return cls(exterior_boundary=exterior, holes=holes, name=name) + @property def _default_plot_args(self) -> dict[str, Any]: return { diff --git a/movement/sample_data.py b/movement/sample_data.py index fcc5abe1f..b4ec8328a 100644 --- a/movement/sample_data.py +++ b/movement/sample_data.py @@ -6,9 +6,7 @@ are used. """ -import logging import shutil -from contextlib import contextmanager from pathlib import Path import pooch @@ -17,7 +15,7 @@ from requests.exceptions import RequestException from movement.io import load_bboxes, load_poses -from movement.utils.logging import logger +from movement.utils.logging import hide_pooch_hash_logs, logger # URL to the remote data repository on GIN # noinspection PyInterpreter @@ -34,32 +32,6 @@ METADATA_FILE = "metadata.yaml" -@contextmanager -def hide_pooch_hash_logs(): - """Hide SHA256 hash printouts from ``pooch.retrieve``. - - This context manager temporarily suppresses SHA256 hash messages - when downloading files with Pooch. - """ - logger = pooch.get_logger() - - class HashFilter(logging.Filter): - def filter(self, record): - msg = record.getMessage() - # Suppress only hash display lines - return not ( - "SHA256 hash of downloaded file" in msg - or "Use this value as the 'known_hash'" in msg - ) - - flt = HashFilter() - logger.addFilter(flt) - try: - yield - finally: - logger.removeFilter(flt) - - def _download_metadata_file(file_name: str, data_dir: Path = DATA_DIR) -> Path: """Download the metadata yaml file. diff --git a/movement/utils/logging.py b/movement/utils/logging.py index 0d3e8a9a8..87bbe42ef 100644 --- a/movement/utils/logging.py +++ b/movement/utils/logging.py @@ -2,12 +2,15 @@ import inspect import json +import logging import sys import warnings +from contextlib import contextmanager from datetime import datetime from functools import wraps from pathlib import Path +import pooch from loguru import logger as loguru_logger DEFAULT_LOG_DIRECTORY = Path.home() / ".movement" @@ -157,3 +160,29 @@ def wrapper(*args, **kwargs): return result return wrapper + + +@contextmanager +def hide_pooch_hash_logs(): + """Hide SHA256 hash printouts from ``pooch.retrieve``. + + This context manager temporarily suppresses SHA256 hash messages + when downloading files with Pooch. + """ + pooch_logger = pooch.get_logger() + + class HashFilter(logging.Filter): + def filter(self, record): + msg = record.getMessage() + # Suppress only hash display lines + return not ( + "SHA256 hash of downloaded file" in msg + or "Use this value as the 'known_hash'" in msg + ) + + flt = HashFilter() + pooch_logger.addFilter(flt) + try: + yield + finally: + pooch_logger.removeFilter(flt) diff --git a/tests/test_unit/test_roi/test_save_load.py b/tests/test_unit/test_roi/test_save_load.py new file mode 100644 index 000000000..1eb83f27e --- /dev/null +++ b/tests/test_unit/test_roi/test_save_load.py @@ -0,0 +1,86 @@ +"""Tests for saving and loading regions of interest to/from files.""" + +import json + +import pytest + +from movement.roi import LineOfInterest, PolygonOfInterest + + +class TestROISaveLoad: + """Tests for ROI save/load functionality.""" + + def test_save_and_load_polygon_roi(self, tmp_path, triangle): + """Test round-trip save and load for PolygonOfInterest.""" + file_path = tmp_path / "triangle.json" + + # Save + triangle.to_file(file_path) + + # Verify file exists and has correct content + assert file_path.exists() + data = json.loads(file_path.read_text()) + assert data["roi_type"] == "PolygonOfInterest" + assert data["dimensions"] == 2 + assert data["name"] == "triangle" + assert "geometry_wkt" in data + + # Load + loaded = PolygonOfInterest.from_file(file_path) + + # Verify loaded ROI matches original + assert loaded.name == triangle.name + assert loaded.dimensions == triangle.dimensions + assert loaded.region.equals(triangle.region) + + def test_save_and_load_line_roi(self, tmp_path, segment_of_y_equals_x): + """Test round-trip save and load for LineOfInterest.""" + file_path = tmp_path / "line.json" + + # Save + segment_of_y_equals_x.to_file(file_path) + + # Verify file exists + assert file_path.exists() + data = json.loads(file_path.read_text()) + assert data["roi_type"] == "LineOfInterest" + assert data["dimensions"] == 1 + + # Load + loaded = LineOfInterest.from_file(file_path) + + # Verify loaded ROI matches original + assert loaded.dimensions == segment_of_y_equals_x.dimensions + assert loaded.region.equals(segment_of_y_equals_x.region) + + def test_save_and_load_polygon_with_hole( + self, tmp_path, unit_square_with_hole + ): + """Test round-trip for polygon with interior holes.""" + file_path = tmp_path / "square_with_hole.json" + + # Save + unit_square_with_hole.to_file(file_path) + + # Load + loaded = PolygonOfInterest.from_file(file_path) + + # Verify holes are preserved + assert loaded.region.equals(unit_square_with_hole.region) + assert len(loaded.holes) == len(unit_square_with_hole.holes) + + def test_load_nonexistent_file_raises(self, tmp_path): + """Test that loading a non-existent file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="ROI file not found"): + PolygonOfInterest.from_file(tmp_path / "nonexistent.json") + + def test_save_with_none_name(self, tmp_path, triangle_pts): + """Test saving an ROI with no name.""" + roi = PolygonOfInterest(triangle_pts) # No name provided + file_path = tmp_path / "unnamed.json" + + roi.to_file(file_path) + loaded = PolygonOfInterest.from_file(file_path) + + # Name should be None in both + assert loaded._name is None