-
Notifications
You must be signed in to change notification settings - Fork 95
Feat: Add support for loading datasets from URLs #782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
278d69b
56c3eab
1affbde
af6d4cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,9 +8,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import pooch | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.utils.logging import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.utils.logging import hide_pooch_hash_logs, logger | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.validators.datasets import ValidBboxesInputs | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.validators.files import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| DEFAULT_FRAME_REGEXP, | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -217,8 +218,26 @@ def from_file( | |||||||||||||||||||||||||||||||||||||||||||||||||
| >>> source_software="VIA-tracks", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> fps=30, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> # Load from a URL | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> ds = load_bboxes.from_file( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> "https://github.com/neuroinformatics-unit/movement/raw/main/tests/data/bboxes/VIA_multiple-crabs_5-frames_labels.csv", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> source_software="VIA-tracks", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> fps=30, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Download file if it is a URL | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if str(file_path).startswith(("http://", "https://")): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| with hide_pooch_hash_logs(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| file_path = pooch.retrieve( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| url=file_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| known_hash=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| path=Path( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "~", ".movement", "data", "public_datasets" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ).expanduser(), | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+231
to
+237
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| with hide_pooch_hash_logs(): | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| file_path = Path( | |
| pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| ) |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code block is duplicated from load_poses.py (lines 161-170). Consider extracting this URL downloading logic into a shared helper function to avoid code duplication.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The URL loading feature lacks test coverage. No tests were added to verify that URLs are correctly detected, downloaded, and cached. Consider adding tests that mock the pooch.retrieve call to verify the URL detection logic and that the correct parameters are passed to pooch.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no error handling for potential download failures. If pooch.retrieve fails (e.g., due to network issues, invalid URL, or server errors), the error will propagate without context. Consider wrapping the pooch.retrieve call in a try-except block to provide more informative error messages to users, such as "Failed to download file from URL: {url}. Please check your internet connection and verify the URL is accessible."
| with hide_pooch_hash_logs(): | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| try: | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| ) | |
| except Exception as exc: | |
| error_msg = ( | |
| f"Failed to download file from URL: {file_path}. " | |
| "Please check your internet connection and verify the URL is accessible." | |
| ) | |
| logger.error(error_msg) | |
| raise RuntimeError(error_msg) from exc |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,12 +6,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import h5py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pooch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pynwb | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sleap_io.io.slp import read_labels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sleap_io.model.labels import Labels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.utils.logging import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.utils.logging import hide_pooch_hash_logs, logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.validators.datasets import ValidPosesInputs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from movement.validators.files import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValidAniposeCSV, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -149,8 +150,26 @@ def from_file( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> ds = load_poses.from_file( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... "path/to/file.h5", source_software="DeepLabCut", fps=30 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> # Load from a URL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> ds = load_poses.from_file( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... "https://github.com/neuroinformatics-unit/movement/raw/main/tests/data/DLC/single-mouse_EPM.predictions.h5", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... source_software="DeepLabCut", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... fps=30 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ... ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Download file if it is a URL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if str(file_path).startswith(("http://", "https://")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with hide_pooch_hash_logs(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| file_path = pooch.retrieve( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| url=file_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| known_hash=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| path=Path( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "~", ".movement", "data", "public_datasets" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+163
to
+168
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with hide_pooch_hash_logs(): | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| # Cache public example datasets under the main movement data directory, | |
| # in a dedicated subdirectory separate from any internal test data cache. | |
| public_cache_dir = Path("~", ".movement", "data", "public_datasets").expanduser() | |
| # Explicitly create the cache directory with user-only permissions where supported. | |
| public_cache_dir.mkdir(parents=True, exist_ok=True, mode=0o700) | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=public_cache_dir, |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pooch.retrieve function returns a string path, but the file_path variable is typed as Path | str in the function signature. After the URL download block, file_path will be a string (returned by pooch), which may cause type inconsistencies downstream if the code expects a Path object. Consider wrapping the result in Path() to maintain type consistency: file_path = Path(pooch.retrieve(...)).
| with hide_pooch_hash_logs(): | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| file_path = Path( | |
| pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| ) |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code block for URL downloading is duplicated in both load_poses.py and load_bboxes.py. Consider extracting this logic into a helper function to reduce duplication and improve maintainability. The helper function could be placed in a utility module or in a common location accessible to both modules.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR description states this addresses Issue #340, which is about adding support for downloading publicly available datasets like CalMS21 and Rat7M. However, the implementation is a more general URL loading feature without specific support for these public datasets. While this is a step towards the goal, consider clarifying in the PR description that this is a foundational feature for future public dataset support, rather than the complete solution for Issue #340.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no error handling for potential download failures. If pooch.retrieve fails (e.g., due to network issues, invalid URL, or server errors), the error will propagate without context. Consider wrapping the pooch.retrieve call in a try-except block to provide more informative error messages to users, such as "Failed to download file from URL: {url}. Please check your internet connection and verify the URL is accessible."
| if str(file_path).startswith(("http://", "https://")): | |
| with hide_pooch_hash_logs(): | |
| file_path = pooch.retrieve( | |
| url=file_path, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| url = str(file_path) | |
| if url.startswith(("http://", "https://")): | |
| try: | |
| file_path = pooch.retrieve( | |
| url=url, | |
| known_hash=None, | |
| path=Path( | |
| "~", ".movement", "data", "public_datasets" | |
| ).expanduser(), | |
| progressbar=True, | |
| ) | |
| except Exception as exc: | |
| error_msg = ( | |
| f"Failed to download file from URL: {url}. " | |
| "Please check your internet connection and verify the URL is accessible." | |
| ) | |
| logger.error(error_msg) | |
| raise RuntimeError(error_msg) from exc |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The URL loading feature lacks test coverage. No tests were added to verify that URLs are correctly detected, downloaded, and cached. Consider adding tests that mock the pooch.retrieve call to verify the URL detection logic and that the correct parameters are passed to pooch.
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using the hide_pooch_hash_logs() context manager (from movement.sample_data) when downloading files to suppress SHA256 hash messages, which may not be useful for end users. The existing sample_data.py module uses this pattern to provide a cleaner user experience. Example: with hide_pooch_hash_logs(): file_path = pooch.retrieve(...)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,8 +2,10 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Callable, Hashable, Sequence | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast | ||
|
|
||
| import matplotlib.pyplot as plt | ||
|
|
@@ -167,7 +169,7 @@ def _boundary_angle_computation( | |
| Spatial position data, that is passed to | ||
| ``how_to_compute_vector_to_region`` and used to compute the | ||
| "vector to the region". | ||
| reference_vector : xarray.DataArray | np.ndarray | ||
| reference_vector : xarray.DataArray | numpy.ndarray | ||
| Constant or time-varying vector to take signed angle with the | ||
| "vector to the region". | ||
| how_to_compute_vector_to_region : Callable | ||
|
|
@@ -335,7 +337,7 @@ def compute_nearest_point_to( | |
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray | ||
| numpy.ndarray | ||
| Coordinates of the point on ``self`` that is closest to | ||
| ``position``. | ||
|
|
||
|
|
@@ -380,7 +382,7 @@ def compute_approach_vector( | |
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray | ||
| numpy.ndarray | ||
| Approach vector from the point to the region. | ||
|
|
||
| See Also | ||
|
|
@@ -438,7 +440,7 @@ def compute_allocentric_angle_to_nearest_point( | |
| in_degrees : bool | ||
| If ``True``, angles are returned in degrees. Otherwise angles are | ||
| returned in radians. Default ``False``. | ||
| reference_vector : np.ndarray or xarray.DataArray or None | ||
| reference_vector : ArrayLike | xarray.DataArray | ||
| The reference vector to be used. Dimensions must be compatible with | ||
| the argument of the same name that is passed to | ||
| :func:`compute_signed_angle_2d`. Default ``(1., 0.)``. | ||
|
|
@@ -560,3 +562,81 @@ def plot( | |
| if fig is None or ax is None: | ||
| fig, ax = plt.subplots(1, 1) | ||
| return self._plot(fig, ax, **matplotlib_kwargs) | ||
|
|
||
| def to_file(self, path: str | Path) -> None: | ||
| """Save the region of interest to a file. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| path : str | Path | ||
| Path to save the ROI file. The file will be saved in JSON format. | ||
|
|
||
| See Also | ||
| -------- | ||
| from_file : Load a region of interest from a file. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> from movement.roi import PolygonOfInterest | ||
| >>> roi = PolygonOfInterest([(0, 0), (1, 0), (1, 1)], name="triangle") | ||
| >>> roi.to_file("my_roi.json") # doctest: +SKIP | ||
|
|
||
| """ | ||
| data = { | ||
| "name": self._name, | ||
| "geometry_wkt": self.region.wkt, | ||
| "dimensions": self.dimensions, | ||
| "roi_type": self.__class__.__name__, | ||
| } | ||
| Path(path).write_text(json.dumps(data, indent=2)) | ||
|
Comment on lines
+566
to
+591
|
||
|
|
||
| @classmethod | ||
| def from_file(cls, path: str | Path) -> BaseRegionOfInterest: | ||
| """Load a region of interest from a file. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| path : str | Path | ||
| Path to the ROI file to load. Must be a JSON file saved by | ||
| :meth:`to_file`. | ||
|
|
||
| Returns | ||
| ------- | ||
| BaseRegionOfInterest | ||
| The loaded region of interest object. The specific subclass | ||
| (LineOfInterest or PolygonOfInterest) is determined from the file. | ||
|
|
||
| Raises | ||
| ------ | ||
| FileNotFoundError | ||
| If the specified file does not exist. | ||
|
|
||
| See Also | ||
| -------- | ||
| to_file : Save a region of interest to a file. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> from movement.roi import PolygonOfInterest | ||
| >>> roi = PolygonOfInterest.from_file("my_roi.json") # doctest: +SKIP | ||
|
|
||
| """ | ||
| file_path = Path(path) | ||
| if not file_path.exists(): | ||
| raise FileNotFoundError(f"ROI file not found: {path}") | ||
|
|
||
| data = json.loads(file_path.read_text()) | ||
| geometry = shapely.from_wkt(data["geometry_wkt"]) | ||
|
Comment on lines
+628
to
+629
|
||
|
|
||
| # Import here to avoid circular imports | ||
| from movement.roi import LineOfInterest, PolygonOfInterest | ||
|
|
||
| roi_type = data.get("roi_type", "") | ||
| if roi_type == "LineOfInterest" or data["dimensions"] == 1: | ||
| return LineOfInterest._from_geometry( | ||
| geometry, name=data.get("name") | ||
| ) | ||
| else: | ||
| return PolygonOfInterest._from_geometry( | ||
| geometry, name=data.get("name") | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,6 +85,35 @@ def __init__( | |
| ) | ||
| super().__init__(geometry=polygon, name=name) | ||
|
|
||
| @classmethod | ||
| def _from_geometry( | ||
| cls, | ||
| geometry: shapely.Polygon, | ||
| name: str | None = None, | ||
| ) -> PolygonOfInterest: | ||
| """Construct a PolygonOfInterest from a shapely geometry. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| geometry : shapely.Polygon | ||
| The shapely geometry to construct from. | ||
| name : str, optional | ||
| Name for the PolygonOfInterest. | ||
|
|
||
| Returns | ||
| ------- | ||
| PolygonOfInterest | ||
| A new PolygonOfInterest instance. | ||
|
|
||
| """ | ||
| exterior = geometry.exterior.coords | ||
| holes = ( | ||
| [interior.coords for interior in geometry.interiors] | ||
| if geometry.interiors | ||
| else None | ||
| ) | ||
| return cls(exterior_boundary=exterior, holes=holes, name=name) | ||
|
Comment on lines
+88
to
+115
|
||
|
|
||
| @property | ||
| def _default_plot_args(self) -> dict[str, Any]: | ||
| return { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The URL validation only checks if the path starts with "http://" or "https://". Consider adding additional validation to ensure the URL is well-formed and potentially restricting to trusted domains if appropriate. While pooch.retrieve is secure by default, adding URL validation could prevent potential issues with malformed URLs.