Skip to content

Commit 08ebada

Browse files
raspstephanWeatherBenchX authors
authored andcommitted
Add a subsampling interpolation method
PiperOrigin-RevId: 874562875
1 parent 44e0e36 commit 08ebada

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

weatherbenchX/interpolations.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,39 @@ def interpolate_data_array(
433433
),
434434
)
435435
return out
436+
437+
438+
class Subsample(Interpolation):
439+
"""Subsample a DataArray along specified dimensions.
440+
441+
This is useful for reducing the resolution of a dataset without interpolation,
442+
e.g. for faster evaluation at lower resolution.
443+
"""
444+
445+
def __init__(
446+
self,
447+
dims: Sequence[str],
448+
stride: int,
449+
):
450+
"""Init.
451+
452+
Args:
453+
dims: Dimensions along which to subsample.
454+
stride: Stride for subsampling. Must be a positive integer.
455+
"""
456+
if stride < 1:
457+
raise ValueError(f'stride must be >= 1, got {stride}')
458+
self._dims = dims
459+
self._stride = stride
460+
461+
def interpolate_data_array(
462+
self,
463+
da: xr.DataArray,
464+
reference: Optional[xr.DataArray] = None,
465+
) -> xr.DataArray:
466+
isel_kwargs = {
467+
dim: slice(None, None, self._stride)
468+
for dim in self._dims
469+
if dim in da.dims
470+
}
471+
return da.isel(**isel_kwargs)

weatherbenchX/interpolations_test.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,90 @@ def test_crop_to_box_invalid_lat(self):
220220
interpolations.CropToBox(lat_min=10, lat_max=-10, lon_min=0, lon_max=10)
221221

222222

223+
class SubsampleTest(absltest.TestCase):
224+
225+
def test_subsample_basic(self):
226+
lats = np.arange(0, 100, 1.0)
227+
lons = np.arange(0, 200, 1.0)
228+
da = xr.DataArray(
229+
name='t2m',
230+
data=np.random.rand(len(lats), len(lons)),
231+
coords={'latitude': lats, 'longitude': lons},
232+
dims=['latitude', 'longitude'],
233+
)
234+
subsampler = interpolations.Subsample(
235+
dims=['latitude', 'longitude'], stride=10
236+
)
237+
result = subsampler.interpolate_data_array(da)
238+
self.assertEqual(result.sizes['latitude'], 10)
239+
self.assertEqual(result.sizes['longitude'], 20)
240+
np.testing.assert_equal(result.latitude.values, lats[::10])
241+
np.testing.assert_equal(result.longitude.values, lons[::10])
242+
243+
def test_subsample_stride_1_is_noop(self):
244+
lats = np.arange(0, 10, 1.0)
245+
lons = np.arange(0, 20, 1.0)
246+
da = xr.DataArray(
247+
name='t2m',
248+
data=np.random.rand(len(lats), len(lons)),
249+
coords={'latitude': lats, 'longitude': lons},
250+
dims=['latitude', 'longitude'],
251+
)
252+
subsampler = interpolations.Subsample(
253+
dims=['latitude', 'longitude'], stride=1
254+
)
255+
result = subsampler.interpolate_data_array(da)
256+
xr.testing.assert_equal(result, da)
257+
258+
def test_subsample_missing_dim_is_skipped(self):
259+
lats = np.arange(0, 10, 1.0)
260+
da = xr.DataArray(
261+
name='t2m',
262+
data=np.random.rand(len(lats)),
263+
coords={'latitude': lats},
264+
dims=['latitude'],
265+
)
266+
subsampler = interpolations.Subsample(
267+
dims=['latitude', 'longitude'], stride=2
268+
)
269+
result = subsampler.interpolate_data_array(da)
270+
self.assertEqual(result.sizes['latitude'], 5)
271+
self.assertNotIn('longitude', result.dims)
272+
273+
def test_subsample_single_dim(self):
274+
lats = np.arange(0, 12, 1.0)
275+
lons = np.arange(0, 20, 1.0)
276+
da = xr.DataArray(
277+
name='t2m',
278+
data=np.random.rand(len(lats), len(lons)),
279+
coords={'latitude': lats, 'longitude': lons},
280+
dims=['latitude', 'longitude'],
281+
)
282+
subsampler = interpolations.Subsample(dims=['latitude'], stride=3)
283+
result = subsampler.interpolate_data_array(da)
284+
self.assertEqual(result.sizes['latitude'], 4)
285+
self.assertEqual(result.sizes['longitude'], 20)
286+
287+
def test_subsample_invalid_stride(self):
288+
with self.assertRaisesRegex(ValueError, 'stride must be >= 1'):
289+
interpolations.Subsample(dims=['latitude'], stride=0)
290+
291+
def test_subsample_via_interpolate(self):
292+
lats = np.arange(0, 10, 1.0)
293+
lons = np.arange(0, 20, 1.0)
294+
ds = {
295+
't2m': xr.DataArray(
296+
data=np.random.rand(len(lats), len(lons)),
297+
coords={'latitude': lats, 'longitude': lons},
298+
dims=['latitude', 'longitude'],
299+
),
300+
}
301+
subsampler = interpolations.Subsample(
302+
dims=['latitude', 'longitude'], stride=2
303+
)
304+
result = subsampler.interpolate(ds)
305+
self.assertEqual(result['t2m'].sizes['latitude'], 5)
306+
self.assertEqual(result['t2m'].sizes['longitude'], 10)
223307

224308

225309
if __name__ == '__main__':

0 commit comments

Comments
 (0)