Skip to content

Commit 2031dd8

Browse files
authored
Merge pull request #833 from davidhassell/regrid-dtos
Allow 'nearest_dtos' 2-d regridding to work with discrete sampling geometry source grids
2 parents a66687f + 0759153 commit 2031dd8

File tree

5 files changed

+137
-21
lines changed

5 files changed

+137
-21
lines changed

Changelog.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ version NEXTVERSION
33

44
**2024-??-??**
55

6+
* Allow ``'nearest_dtos'`` 2-d regridding to work with discrete
7+
sampling geometry source grids
8+
(https://github.com/NCAS-CMS/cf-python/issues/832)
69
* New method: `cf.Field.filled`
710
(https://github.com/NCAS-CMS/cf-python/issues/811)
811
* New method: `cf.Field.is_discrete_axis`

cf/data/dask_regrid.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,17 +507,20 @@ def _regrid(
507507
# Note: It is much more efficient to access
508508
# 'weights.indptr', 'weights.indices', and
509509
# 'weights.data' directly, rather than iterating
510-
# over rows of 'weights' and using 'weights.getrow'.
510+
# over rows of 'weights' and using
511+
# 'weights.getrow'. Also, 'np.count_nonzero' is much
512+
# faster than 'np.any' and 'np.all'.
511513
count_nonzero = np.count_nonzero
512514
indptr = weights.indptr.tolist()
513515
indices = weights.indices
514516
data = weights.data
515517
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
516518
mask = src_mask[indices[i0:i1]]
517-
if not count_nonzero(mask):
519+
n_masked = count_nonzero(mask)
520+
if not n_masked:
518521
continue
519522

520-
if mask.all():
523+
if n_masked == mask.size:
521524
dst_mask[j] = True
522525
continue
523526

@@ -529,8 +532,8 @@ def _regrid(
529532

530533
del indptr
531534

532-
elif method in ("linear", "bilinear", "nearest_dtos"):
533-
# 2) Linear and nearest neighbour methods:
535+
elif method in ("linear", "bilinear"):
536+
# 2) Linear methods:
534537
#
535538
# Mask out any row j that contains at least one positive
536539
# (i.e. greater than or equal to 'min_weight') w_ji that
@@ -546,7 +549,9 @@ def _regrid(
546549
# Note: It is much more efficient to access
547550
# 'weights.indptr', 'weights.indices', and
548551
# 'weights.data' directly, rather than iterating
549-
# over rows of 'weights' and using 'weights.getrow'.
552+
# over rows of 'weights' and using
553+
# 'weights.getrow'. Also, 'np.count_nonzero' is much
554+
# faster than 'np.any' and 'np.all'.
550555
count_nonzero = np.count_nonzero
551556
where = np.where
552557
indptr = weights.indptr.tolist()
@@ -562,12 +567,45 @@ def _regrid(
562567

563568
del indptr, pos_data
564569

570+
elif method == "nearest_dtos":
571+
# 3) Nearest neighbour dtos method:
572+
#
573+
# Set to 0 any weight that corresponds to a masked source
574+
# grid cell.
575+
#
576+
# Mask out any row j for which all source grid cells are
577+
# masked.
578+
dst_size = weights.shape[0]
579+
if dst_mask is None:
580+
dst_mask = np.zeros((dst_size,), dtype=bool)
581+
else:
582+
dst_mask = dst_mask.copy()
583+
584+
# Note: It is much more efficient to access
585+
# 'weights.indptr', 'weights.indices', and
586+
# 'weights.data' directly, rather than iterating
587+
# over rows of 'weights' and using
588+
# 'weights.getrow'. Also, 'np.count_nonzero' is much
589+
# faster than 'np.any' and 'np.all'.
590+
count_nonzero = np.count_nonzero
591+
indptr = weights.indptr.tolist()
592+
indices = weights.indices
593+
for j, (i0, i1) in enumerate(zip(indptr[:-1], indptr[1:])):
594+
mask = src_mask[indices[i0:i1]]
595+
n_masked = count_nonzero(mask)
596+
if n_masked == mask.size:
597+
dst_mask[j] = True
598+
elif n_masked:
599+
weights.data[np.arange(i0, i1)[mask]] = 0
600+
601+
del indptr
602+
565603
elif method in (
566604
"patch",
567605
"conservative_2nd",
568606
"nearest_stod",
569607
):
570-
# 3) Patch recovery and second-order conservative methods:
608+
# 4) Patch recovery and second-order conservative methods:
571609
#
572610
# A reference source data mask has already been
573611
# incorporated into the weights matrix, and 'a' is assumed

cf/docstring/docstring.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@
161161
mapped to the closest destination point. A
162162
destination point can be mapped to multiple source
163163
points. Some destination points may not be
164-
mapped. Useful for regridding of categorical data.
164+
mapped. Each regridded value is the sum of its
165+
contributing source elements. Useful for binning or
166+
for categorical data.
165167
166168
* `None`: This is the default and can only be used
167169
when *dst* is a `RegridOperator`.""",

cf/regrid/regrid.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,10 @@ def regrid(
523523
"are a UGRID mesh"
524524
)
525525

526-
if src_grid.is_locstream or dst_grid.is_locstream:
526+
if dst_grid.is_locstream:
527527
raise ValueError(
528-
f"{method!r} regridding is (at the moment) only available "
529-
"when neither the source and destination grids are "
530-
"DSG featureTypes."
528+
f"{method!r} regridding is (at the moment) not available "
529+
"when the destination grid is a DSG featureType."
531530
)
532531

533532
elif cartesian and (src_grid.is_mesh or dst_grid.is_mesh):
@@ -656,6 +655,7 @@ def regrid(
656655
dst=dst,
657656
weights_file=weights_file if from_file else None,
658657
src_mesh_location=src_grid.mesh_location,
658+
src_featureType=src_grid.featureType,
659659
dst_featureType=dst_grid.featureType,
660660
src_z=src_grid.z,
661661
dst_z=dst_grid.z,
@@ -674,6 +674,9 @@ def regrid(
674674
)
675675

676676
if return_operator:
677+
# Note: The `RegridOperator.tosparse` method will also set
678+
# 'dst_mask' to False for destination points with all
679+
# zero weights.
677680
regrid_operator.tosparse()
678681
return regrid_operator
679682

@@ -1279,7 +1282,7 @@ def spherical_grid(
12791282

12801283
# Set cyclicity of X axis
12811284
if mesh_location or featureType:
1282-
cyclic = None
1285+
cyclic = False
12831286
elif cyclic is None:
12841287
cyclic = f.iscyclic(x_axis)
12851288
else:
@@ -2281,6 +2284,11 @@ def create_esmpy_locstream(grid, mask=None):
22812284
# but the esmpy mask requires 0/1 for masked/unmasked
22822285
# elements.
22832286
mask = np.invert(mask).astype("int32")
2287+
if mask.size == 1:
2288+
# Make sure that there's a mask element for each point in
2289+
# the locstream (rather than a scalar that applies to all
2290+
# elements).
2291+
mask = np.full((location_count,), mask, dtype="int32")
22842292
else:
22852293
# No masked points
22862294
mask = np.full((location_count,), 1, dtype="int32")

cf/test/test_regrid_featureType.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,6 @@
2525
except ImportError:
2626
pass
2727

28-
disallowed_methods = (
29-
"conservative",
30-
"conservative_2nd",
31-
"nearest_dtos",
32-
)
33-
3428
methods = (
3529
"linear",
3630
"nearest_stod",
@@ -169,6 +163,78 @@ def test_Field_regrid_grid_to_featureType_3d(self):
169163
else:
170164
self.assertFalse(y.mask.any())
171165

166+
@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
167+
def test_Field_regrid_featureType_to_grid_2d(self):
168+
self.assertFalse(cf.regrid_logging())
169+
170+
# Create some nice data
171+
src = self.dst_featureType
172+
src.del_construct("cellmethod0")
173+
src = src[:12]
174+
src[...] = 273 + np.arange(12)
175+
x = src.coord("X")
176+
x[...] = [4, 6, 9, 11, 14, 16, 4, 6, 9, 11, 14, 16]
177+
y = src.coord("Y")
178+
y[...] = [41, 41, 31, 31, 21, 21, 39, 39, 29, 29, 19, 19]
179+
180+
dst = self.src_grid.copy()
181+
x = dst.coord("X")
182+
x[...] = [5, 10, 15, 20]
183+
y = dst.coord("Y")
184+
y[...] = [10, 20, 30, 40]
185+
186+
# Mask some destination grid points
187+
dst[0, 0, 1, 2] = cf.masked
188+
189+
# Expected destination regridded values
190+
y0 = np.ma.array(
191+
[[0, 0, 0, 0], [0, 0, 1122, 0], [0, 1114, 0, 0], [1106, 0, 0, 0]],
192+
mask=[
193+
[True, True, True, True],
194+
[True, True, False, True],
195+
[True, False, True, True],
196+
[False, True, True, True],
197+
],
198+
)
199+
200+
for src_masked in (False, True):
201+
y = y0.copy()
202+
if src_masked:
203+
src = src.copy()
204+
src[6:8] = cf.masked
205+
# The following element should be smaller, because it
206+
# now only has two source cells contributing to it,
207+
# rather than four.
208+
y[3, 0] = 547
209+
210+
# Loop over whether or not to use the destination grid
211+
# masked points
212+
for use_dst_mask in (False, True):
213+
if use_dst_mask:
214+
y = y.copy()
215+
y[1, 2] = np.ma.masked
216+
217+
kwargs = {"use_dst_mask": use_dst_mask}
218+
method = "nearest_dtos"
219+
for return_operator in (False, True):
220+
if return_operator:
221+
r = src.regrids(
222+
dst, method=method, return_operator=True, **kwargs
223+
)
224+
x = src.regrids(r)
225+
else:
226+
x = src.regrids(dst, method=method, **kwargs)
227+
228+
a = x.array
229+
230+
self.assertEqual(y.size, a.size)
231+
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))
232+
233+
if isinstance(a, np.ma.MaskedArray):
234+
self.assertTrue((y.mask == a.mask).all())
235+
else:
236+
self.assertFalse(y.mask.any())
237+
172238
@unittest.skipUnless(esmpy_imported, "Requires esmpy/ESMF package.")
173239
def test_Field_regrid_grid_to_featureType_2d(self):
174240
self.assertFalse(cf.regrid_logging())
@@ -196,7 +262,6 @@ def test_Field_regrid_grid_to_featureType_2d(self):
196262
a = x.array
197263

198264
y = esmpy_regrid(coord_sys, method, src, dst, **kwargs)
199-
200265
self.assertEqual(y.size, a.size)
201266
self.assertTrue(np.allclose(y, a, atol=atol, rtol=rtol))
202267

@@ -259,7 +324,7 @@ def test_Field_regrid_featureType_bad_methods(self):
259324
dst = self.dst_featureType.copy()
260325
src = self.src_grid.copy()
261326

262-
for method in disallowed_methods:
327+
for method in ("conservative", "conservative_2nd"):
263328
with self.assertRaises(ValueError):
264329
src.regrids(dst, method=method)
265330

0 commit comments

Comments
 (0)