|
| 1 | +"""Tests for the ggplot-style pipe syntax and verbs.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | +import xarray as xr |
| 8 | + |
| 9 | +import cubedynamics as cd |
| 10 | + |
| 11 | + |
| 12 | +def _make_time_series(count: int = 12): |
| 13 | + time = pd.date_range("2000-01-01", periods=count, freq="MS") |
| 14 | + data = xr.DataArray( |
| 15 | + np.arange(count, dtype=float), |
| 16 | + dims=("time",), |
| 17 | + coords={"time": time}, |
| 18 | + ) |
| 19 | + return data |
| 20 | + |
| 21 | + |
| 22 | +def test_pipe_basic_chain(): |
| 23 | + da = _make_time_series() |
| 24 | + |
| 25 | + result = (cd.pipe(da) | cd.anomaly(dim="time") | cd.variance(dim="time")).unwrap() |
| 26 | + |
| 27 | + assert isinstance(result, xr.DataArray) |
| 28 | + assert result.dims == () |
| 29 | + assert float(result) >= 0 |
| 30 | + |
| 31 | + |
| 32 | +def test_month_filter_reduces_time(): |
| 33 | + da = _make_time_series(24) |
| 34 | + |
| 35 | + summer = (cd.pipe(da) | cd.month_filter([6, 7, 8])).unwrap() |
| 36 | + |
| 37 | + assert set(int(m) for m in summer["time"].dt.month.values) == {6, 7, 8} |
| 38 | + |
| 39 | + |
| 40 | +def test_to_netcdf_roundtrip(tmp_path): |
| 41 | + da = _make_time_series() |
| 42 | + path = tmp_path / "out.nc" |
| 43 | + |
| 44 | + result = (cd.pipe(da) | cd.to_netcdf(path)).unwrap() |
| 45 | + |
| 46 | + assert path.exists() |
| 47 | + loaded = xr.load_dataarray(path) |
| 48 | + xr.testing.assert_identical(da, loaded) |
| 49 | + xr.testing.assert_identical(da, result) |
0 commit comments