@@ -220,6 +220,90 @@ def test_crop_to_box_invalid_lat(self):
220220 interpolations .CropToBox (lat_min = 10 , lat_max = - 10 , lon_min = 0 , lon_max = 10 )
221221
222222
223+ class SubsampleTest (absltest .TestCase ):
224+
225+ def test_subsample_basic (self ):
226+ lats = np .arange (0 , 100 , 1.0 )
227+ lons = np .arange (0 , 200 , 1.0 )
228+ da = xr .DataArray (
229+ name = 't2m' ,
230+ data = np .random .rand (len (lats ), len (lons )),
231+ coords = {'latitude' : lats , 'longitude' : lons },
232+ dims = ['latitude' , 'longitude' ],
233+ )
234+ subsampler = interpolations .Subsample (
235+ dims = ['latitude' , 'longitude' ], stride = 10
236+ )
237+ result = subsampler .interpolate_data_array (da )
238+ self .assertEqual (result .sizes ['latitude' ], 10 )
239+ self .assertEqual (result .sizes ['longitude' ], 20 )
240+ np .testing .assert_equal (result .latitude .values , lats [::10 ])
241+ np .testing .assert_equal (result .longitude .values , lons [::10 ])
242+
243+ def test_subsample_stride_1_is_noop (self ):
244+ lats = np .arange (0 , 10 , 1.0 )
245+ lons = np .arange (0 , 20 , 1.0 )
246+ da = xr .DataArray (
247+ name = 't2m' ,
248+ data = np .random .rand (len (lats ), len (lons )),
249+ coords = {'latitude' : lats , 'longitude' : lons },
250+ dims = ['latitude' , 'longitude' ],
251+ )
252+ subsampler = interpolations .Subsample (
253+ dims = ['latitude' , 'longitude' ], stride = 1
254+ )
255+ result = subsampler .interpolate_data_array (da )
256+ xr .testing .assert_equal (result , da )
257+
258+ def test_subsample_missing_dim_is_skipped (self ):
259+ lats = np .arange (0 , 10 , 1.0 )
260+ da = xr .DataArray (
261+ name = 't2m' ,
262+ data = np .random .rand (len (lats )),
263+ coords = {'latitude' : lats },
264+ dims = ['latitude' ],
265+ )
266+ subsampler = interpolations .Subsample (
267+ dims = ['latitude' , 'longitude' ], stride = 2
268+ )
269+ result = subsampler .interpolate_data_array (da )
270+ self .assertEqual (result .sizes ['latitude' ], 5 )
271+ self .assertNotIn ('longitude' , result .dims )
272+
273+ def test_subsample_single_dim (self ):
274+ lats = np .arange (0 , 12 , 1.0 )
275+ lons = np .arange (0 , 20 , 1.0 )
276+ da = xr .DataArray (
277+ name = 't2m' ,
278+ data = np .random .rand (len (lats ), len (lons )),
279+ coords = {'latitude' : lats , 'longitude' : lons },
280+ dims = ['latitude' , 'longitude' ],
281+ )
282+ subsampler = interpolations .Subsample (dims = ['latitude' ], stride = 3 )
283+ result = subsampler .interpolate_data_array (da )
284+ self .assertEqual (result .sizes ['latitude' ], 4 )
285+ self .assertEqual (result .sizes ['longitude' ], 20 )
286+
287+ def test_subsample_invalid_stride (self ):
288+ with self .assertRaisesRegex (ValueError , 'stride must be >= 1' ):
289+ interpolations .Subsample (dims = ['latitude' ], stride = 0 )
290+
291+ def test_subsample_via_interpolate (self ):
292+ lats = np .arange (0 , 10 , 1.0 )
293+ lons = np .arange (0 , 20 , 1.0 )
294+ ds = {
295+ 't2m' : xr .DataArray (
296+ data = np .random .rand (len (lats ), len (lons )),
297+ coords = {'latitude' : lats , 'longitude' : lons },
298+ dims = ['latitude' , 'longitude' ],
299+ ),
300+ }
301+ subsampler = interpolations .Subsample (
302+ dims = ['latitude' , 'longitude' ], stride = 2
303+ )
304+ result = subsampler .interpolate (ds )
305+ self .assertEqual (result ['t2m' ].sizes ['latitude' ], 5 )
306+ self .assertEqual (result ['t2m' ].sizes ['longitude' ], 10 )
223307
224308
225309if __name__ == '__main__' :
0 commit comments