# -----------------------------------------------------------------------------
# weatherbenchX/interpolations.py
# Appended at the end of the module, after the existing module-level
# `interpolate_data_array` definition.
# -----------------------------------------------------------------------------


class Subsample(Interpolation):
  """Keeps every `stride`-th element of a DataArray along chosen dimensions.

  No values are interpolated or averaged: the result is a strided positional
  selection starting at index 0 of each requested dimension. This is useful
  for reducing the resolution of a dataset, e.g. for faster evaluation at
  lower resolution.
  """

  def __init__(
      self,
      dims: Sequence[str],
      stride: int,
  ):
    """Init.

    Args:
      dims: Dimensions along which to subsample. A single dimension name given
        as a bare string is treated as a one-element sequence.
      stride: Stride for subsampling. Must be a positive integer.

    Raises:
      ValueError: If `stride` is smaller than 1.
      TypeError: If `stride` is not an integer (e.g. a float such as 2.5).
    """
    if stride < 1:
      raise ValueError(f'stride must be >= 1, got {stride}')
    # A non-integer stride would only fail later, inside the slicing call, with
    # a message that does not mention `stride`. Reject it here instead.
    # `__index__` is what Python itself requires of slice steps, so this also
    # admits NumPy integer scalars.
    if not hasattr(stride, '__index__'):
      raise TypeError(
          f'stride must be an integer, got {type(stride).__name__}: {stride!r}'
      )
    # A bare string is itself a sequence of characters; without this guard
    # `dims='latitude'` would be read as dimensions 'l', 'a', 't', ... and the
    # subsampling would silently do nothing.
    if isinstance(dims, str):
      dims = (dims,)
    # Snapshot into a tuple so later mutation of the caller's list cannot
    # change the behaviour of this instance.
    self._dims = tuple(dims)
    self._stride = int(stride)

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    """Returns `da` subsampled along the configured dimensions.

    Args:
      da: Array to subsample. Configured dimensions that `da` does not have are
        skipped rather than treated as an error.
      reference: Not used by this method; the output depends only on `da`.

    Returns:
      The result of a positional `isel` on `da` with step `stride` along every
      configured dimension present in `da`.
    """
    indexers = {
        dim: slice(None, None, self._stride)
        for dim in self._dims
        if dim in da.dims
    }
    # Pass the mapping positionally instead of as **kwargs: a dimension that
    # happens to be called 'drop' or 'missing_dims' would otherwise be taken
    # for one of `isel`'s own keyword arguments.
    return da.isel(indexers)


# -----------------------------------------------------------------------------
# weatherbenchX/interpolations_test.py
# Inserted after the CropToBox tests (`test_crop_to_box_invalid_lat`) and
# before the module's `if __name__ == '__main__':` guard.
# -----------------------------------------------------------------------------


class SubsampleTest(absltest.TestCase):

  def _make_grid(self, n_lat, n_lon):
    """Returns a deterministic (latitude, longitude) DataArray named 't2m'.

    Every cell holds a distinct value, so the tests can check *which* elements
    were kept and not only how many.
    """
    return xr.DataArray(
        name='t2m',
        data=np.arange(n_lat * n_lon, dtype=float).reshape(n_lat, n_lon),
        coords={
            'latitude': np.arange(n_lat, dtype=float),
            'longitude': np.arange(n_lon, dtype=float),
        },
        dims=['latitude', 'longitude'],
    )

  def test_subsample_basic(self):
    da = self._make_grid(100, 200)
    subsampler = interpolations.Subsample(
        dims=['latitude', 'longitude'], stride=10
    )
    result = subsampler.interpolate_data_array(da)
    self.assertEqual(result.sizes['latitude'], 10)
    self.assertEqual(result.sizes['longitude'], 20)
    np.testing.assert_equal(
        result.latitude.values, da.latitude.values[::10]
    )
    np.testing.assert_equal(
        result.longitude.values, da.longitude.values[::10]
    )
    np.testing.assert_equal(result.values, da.values[::10, ::10])

  def test_subsample_stride_1_is_noop(self):
    da = self._make_grid(10, 20)
    subsampler = interpolations.Subsample(
        dims=['latitude', 'longitude'], stride=1
    )
    result = subsampler.interpolate_data_array(da)
    xr.testing.assert_equal(result, da)

  def test_subsample_missing_dim_is_skipped(self):
    lats = np.arange(0, 10, 1.0)
    da = xr.DataArray(
        name='t2m',
        data=np.arange(len(lats), dtype=float),
        coords={'latitude': lats},
        dims=['latitude'],
    )
    subsampler = interpolations.Subsample(
        dims=['latitude', 'longitude'], stride=2
    )
    result = subsampler.interpolate_data_array(da)
    self.assertEqual(result.sizes['latitude'], 5)
    self.assertNotIn('longitude', result.dims)

  def test_subsample_single_dim(self):
    da = self._make_grid(12, 20)
    subsampler = interpolations.Subsample(dims=['latitude'], stride=3)
    result = subsampler.interpolate_data_array(da)
    self.assertEqual(result.sizes['latitude'], 4)
    self.assertEqual(result.sizes['longitude'], 20)
    np.testing.assert_equal(result.values, da.values[::3, :])

  def test_subsample_length_not_multiple_of_stride(self):
    # 10 points with stride 3 keep indices 0, 3, 6 and 9.
    da = self._make_grid(10, 20)
    subsampler = interpolations.Subsample(dims=['latitude'], stride=3)
    result = subsampler.interpolate_data_array(da)
    np.testing.assert_equal(
        result.latitude.values, np.array([0.0, 3.0, 6.0, 9.0])
    )

  def test_subsample_dims_given_as_single_string(self):
    da = self._make_grid(10, 20)
    subsampler = interpolations.Subsample(dims='latitude', stride=2)
    result = subsampler.interpolate_data_array(da)
    self.assertEqual(result.sizes['latitude'], 5)
    self.assertEqual(result.sizes['longitude'], 20)

  def test_subsample_dims_are_copied(self):
    da = self._make_grid(10, 20)
    dims = ['latitude']
    subsampler = interpolations.Subsample(dims=dims, stride=2)
    dims.append('longitude')
    result = subsampler.interpolate_data_array(da)
    self.assertEqual(result.sizes['latitude'], 5)
    self.assertEqual(result.sizes['longitude'], 20)

  def test_subsample_invalid_stride(self):
    for bad_stride in (0, -1):
      with self.assertRaisesRegex(ValueError, 'stride must be >= 1'):
        interpolations.Subsample(dims=['latitude'], stride=bad_stride)

  def test_subsample_non_integer_stride(self):
    with self.assertRaisesRegex(TypeError, 'stride must be an integer'):
      interpolations.Subsample(dims=['latitude'], stride=2.5)

  def test_subsample_via_interpolate(self):
    lats = np.arange(0, 10, 1.0)
    lons = np.arange(0, 20, 1.0)
    ds = {
        't2m': xr.DataArray(
            data=np.arange(len(lats) * len(lons), dtype=float).reshape(
                len(lats), len(lons)
            ),
            coords={'latitude': lats, 'longitude': lons},
            dims=['latitude', 'longitude'],
        ),
    }
    subsampler = interpolations.Subsample(
        dims=['latitude', 'longitude'], stride=2
    )
    result = subsampler.interpolate(ds)
    self.assertEqual(result['t2m'].sizes['latitude'], 5)
    self.assertEqual(result['t2m'].sizes['longitude'], 10)