Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions docs/source/user_guide/input_output.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,47 @@ If so, make sure to save them to a netCDF file that satisfies the
[GUI compatibility requirements](target-gui-compatible-netcdf).
:::

(target-zarr)=
## Saving and loading with Zarr

[Zarr](https://zarr.readthedocs.io/) is an open format for storing
chunked, compressed N-dimensional arrays. It is particularly well-suited
for large datasets and cloud or remote storage.
Like netCDF, Zarr is natively supported by xarray.

To save any xarray dataset `ds` to a Zarr store:
```python
ds.to_zarr("path/to/dataset.zarr", consolidated=True)
```

To load it back:
```python
import xarray as xr

ds = xr.open_zarr("path/to/dataset.zarr", consolidated=True)
```

Similarly, an {class}`xarray.DataArray` object (e.g. the `position` variable
of a `movement` dataset) can be saved to a Zarr store using the
{meth}`to_zarr()<xarray.DataArray.to_zarr()>` method.
Because Zarr stores correspond to Dataset objects, the DataArray is internally
converted to a Dataset before saving. To load a specific DataArray back,
load the store as a Dataset and select the variable by name:

```python
da.to_zarr("path/to/dataarray.zarr", consolidated=True)

import xarray as xr

da_loaded = xr.open_zarr("path/to/dataarray.zarr", consolidated=True)["position"]
```

:::{note}
For more details on Zarr I/O options (chunking, compression, cloud
storage backends), refer to the
[xarray documentation on Zarr](https://docs.xarray.dev/en/stable/user-guide/io.html#zarr).
:::

(target-sample-data)=
## Sample data

Expand Down
86 changes: 86 additions & 0 deletions tests/test_integration/test_zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Test saving movement datasets to Zarr stores."""

import pandas as pd
import pytest
import xarray as xr

from movement.filtering import filter_by_confidence, rolling_filter
from movement.kinematics import compute_forward_vector, compute_speed
from movement.transforms import scale


@pytest.fixture
def processed_dataset(valid_poses_dataset):
"""Process a valid poses dataset by applying filters and transforms."""
ds = valid_poses_dataset.copy()
ds["position_filtered"] = filter_by_confidence(
ds["position"], ds["confidence"], threshold=0.5
)
ds["position_smoothed"] = rolling_filter(
ds["position"], window=3, min_periods=2, statistic="median"
)
ds["position_scaled"] = scale(
ds["position_smoothed"], factor=1 / 10, space_unit="cm"
)
return ds


@pytest.fixture
def dataset_with_derived_variables(valid_poses_dataset):
"""Create a dataset with some derived variables."""
ds = valid_poses_dataset.copy()
ds["speed"] = compute_speed(ds["position"])
ds["forward_vector"] = compute_forward_vector(
ds["position"], "left", "right"
)
return ds


@pytest.fixture
def dataset_with_datetime_index(valid_poses_dataset):
"""Create a dataset with a pd.DateTimeIndex as the time coordinate."""
ds = valid_poses_dataset.copy()
timestamps = pd.date_range(
start=pd.Timestamp.now(),
periods=ds.sizes["time"],
freq=pd.Timedelta(seconds=1),
)
ds.assign_coords(time=timestamps)
return ds


@pytest.mark.parametrize(
"dataset",
[
"valid_poses_dataset",
"valid_poses_dataset_with_nan",
"valid_bboxes_dataset", # time unit is in frames
"valid_bboxes_dataset_in_seconds",
"valid_bboxes_dataset_with_nan",
"processed_dataset",
"dataset_with_derived_variables",
"dataset_with_datetime_index",
],
)
def test_ds_save_and_load_zarr(dataset, tmp_path, request):
"""Test that saving a movement dataset to a Zarr store and then
loading it back returns the same Dataset.
"""
ds = request.getfixturevalue(dataset)
zarr_store = tmp_path / "test_dataset.zarr"
ds.to_zarr(zarr_store, consolidated=True)
loaded_ds = xr.open_zarr(zarr_store, consolidated=True)
loaded_ds.load()
xr.testing.assert_allclose(loaded_ds, ds)
assert loaded_ds.attrs == ds.attrs


def test_da_save_and_load_zarr(valid_poses_dataset, tmp_path):
"""Test saving a DataArray to a Zarr store and loading it back."""
da = valid_poses_dataset["position"]
zarr_store = tmp_path / "test_dataarray.zarr"
da.to_zarr(zarr_store, consolidated=True)
loaded_da = xr.open_zarr(zarr_store, consolidated=True)["position"]
loaded_da.load()
xr.testing.assert_allclose(loaded_da, da)
assert loaded_da.attrs == da.attrs
Loading