Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions weatherbenchX/interpolations.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,39 @@ def interpolate_data_array(
),
)
return out


class Subsample(Interpolation):
"""Subsample a DataArray along specified dimensions.

This is useful for reducing the resolution of a dataset without interpolation,
e.g. for faster evaluation at lower resolution.
"""

def __init__(
self,
dims: Sequence[str],
stride: int,
):
"""Init.

Args:
dims: Dimensions along which to subsample.
stride: Stride for subsampling. Must be a positive integer.
"""
if stride < 1:
raise ValueError(f'stride must be >= 1, got {stride}')
self._dims = dims
self._stride = stride

def interpolate_data_array(
self,
da: xr.DataArray,
reference: Optional[xr.DataArray] = None,
) -> xr.DataArray:
isel_kwargs = {
dim: slice(None, None, self._stride)
for dim in self._dims
if dim in da.dims
}
return da.isel(**isel_kwargs)
84 changes: 84 additions & 0 deletions weatherbenchX/interpolations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,90 @@ def test_crop_to_box_invalid_lat(self):
interpolations.CropToBox(lat_min=10, lat_max=-10, lon_min=0, lon_max=10)


class SubsampleTest(absltest.TestCase):

def test_subsample_basic(self):
lats = np.arange(0, 100, 1.0)
lons = np.arange(0, 200, 1.0)
da = xr.DataArray(
name='t2m',
data=np.random.rand(len(lats), len(lons)),
coords={'latitude': lats, 'longitude': lons},
dims=['latitude', 'longitude'],
)
subsampler = interpolations.Subsample(
dims=['latitude', 'longitude'], stride=10
)
result = subsampler.interpolate_data_array(da)
self.assertEqual(result.sizes['latitude'], 10)
self.assertEqual(result.sizes['longitude'], 20)
np.testing.assert_equal(result.latitude.values, lats[::10])
np.testing.assert_equal(result.longitude.values, lons[::10])

def test_subsample_stride_1_is_noop(self):
lats = np.arange(0, 10, 1.0)
lons = np.arange(0, 20, 1.0)
da = xr.DataArray(
name='t2m',
data=np.random.rand(len(lats), len(lons)),
coords={'latitude': lats, 'longitude': lons},
dims=['latitude', 'longitude'],
)
subsampler = interpolations.Subsample(
dims=['latitude', 'longitude'], stride=1
)
result = subsampler.interpolate_data_array(da)
xr.testing.assert_equal(result, da)

def test_subsample_missing_dim_is_skipped(self):
lats = np.arange(0, 10, 1.0)
da = xr.DataArray(
name='t2m',
data=np.random.rand(len(lats)),
coords={'latitude': lats},
dims=['latitude'],
)
subsampler = interpolations.Subsample(
dims=['latitude', 'longitude'], stride=2
)
result = subsampler.interpolate_data_array(da)
self.assertEqual(result.sizes['latitude'], 5)
self.assertNotIn('longitude', result.dims)

def test_subsample_single_dim(self):
lats = np.arange(0, 12, 1.0)
lons = np.arange(0, 20, 1.0)
da = xr.DataArray(
name='t2m',
data=np.random.rand(len(lats), len(lons)),
coords={'latitude': lats, 'longitude': lons},
dims=['latitude', 'longitude'],
)
subsampler = interpolations.Subsample(dims=['latitude'], stride=3)
result = subsampler.interpolate_data_array(da)
self.assertEqual(result.sizes['latitude'], 4)
self.assertEqual(result.sizes['longitude'], 20)

def test_subsample_invalid_stride(self):
with self.assertRaisesRegex(ValueError, 'stride must be >= 1'):
interpolations.Subsample(dims=['latitude'], stride=0)

def test_subsample_via_interpolate(self):
lats = np.arange(0, 10, 1.0)
lons = np.arange(0, 20, 1.0)
ds = {
't2m': xr.DataArray(
data=np.random.rand(len(lats), len(lons)),
coords={'latitude': lats, 'longitude': lons},
dims=['latitude', 'longitude'],
),
}
subsampler = interpolations.Subsample(
dims=['latitude', 'longitude'], stride=2
)
result = subsampler.interpolate(ds)
self.assertEqual(result['t2m'].sizes['latitude'], 5)
self.assertEqual(result['t2m'].sizes['longitude'], 10)


if __name__ == '__main__':
Expand Down
Loading