
Commit 0cd7ef6

Authored by aulemahal, pre-commit-ci[bot] and huard
Remove need for scipy, fix parallel regridder (#461)
* Remove need for scipy; better name collision avoidance in the parallel regridder init
* Update changelog
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update xesmf/smm.py (Co-authored-by: David Huard <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add comments; simplify unname-rename
* Fix: add removal of zero entries to add_nans_to_weights to imitate the previous behaviour

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Huard <[email protected]>
1 parent eef222c commit 0cd7ef6

File tree

4 files changed (+86, -44 lines)

CHANGES.rst

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 What's new
 ==========
 
+0.9.1 (unreleased)
+------------------
+* Remove scipy-dependent code in ``add_nans_to_weights``. By `Pascal Bourgault <https://github.com/aulemahal>`_.
+* Fix some name collision issues in the parallel regridder initialisation. By `Pascal Bourgault <https://github.com/aulemahal>`_.
+
 0.9.0 (2025-11-21)
 ------------------
 * Added ``Regridder`` option ``post_mask_source`` to mask contributions of specified source grid cells, with a special setting for masking domain edge cells to avoid extrapolation with ``nearest_s2d`` when remapping to a larger domain (``post_mask_source = 'domain_edge'``, :pull:`444`). By `Martin Schupfner <https://github.com/sol1105>`_.
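
For context, a minimal sketch of the code path these entries exercise: building a regridder with parallel weight generation. The grids, resolutions and chunk sizes here are hypothetical; ``xe.util.grid_global``, ``Regridder`` and its ``parallel=True`` option are the public xESMF API, and parallel weight generation expects the grids (at least the output one) to carry dask chunks.

import numpy as np
import xarray as xr
import xesmf as xe

# Hypothetical 2-degree source grid and 1-degree target grid (2D lat/lon + bounds).
ds_in = xe.util.grid_global(2.0, 2.0).chunk({'y': 45, 'x': 90})
ds_out = xe.util.grid_global(1.0, 1.0).chunk({'y': 60, 'x': 120})

# parallel=True generates the weights block-by-block with dask; this is the
# initialisation path whose name collisions the commit fixes.
regridder = xe.Regridder(ds_in, ds_out, 'bilinear', parallel=True)

# Apply to a toy field defined on the source grid.
field = xr.DataArray(
    np.random.rand(90, 180),
    dims=('y', 'x'),
    coords={'lat': ds_in.lat, 'lon': ds_in.lon},
)
print(regridder(field).shape)  # (180, 360)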

xesmf/frontend.py

Lines changed: 12 additions & 32 deletions

@@ -22,7 +22,14 @@
     post_apply_target_mask_to_weights,
     read_weights,
 )
-from .util import LAT_CF_ATTRS, LON_CF_ATTRS, _get_edge_indices_2d, split_polygons_and_holes
+from .util import (
+    LAT_CF_ATTRS,
+    LON_CF_ATTRS,
+    _get_edge_indices_2d,
+    _rename_dataset,
+    _unname_dataset,
+    split_polygons_and_holes,
+)
 
 try:
     import dask.array as da
@@ -39,16 +46,8 @@ def subset_regridder(
     kwargs.pop('filename', None)  # Don't save subset of weights
     kwargs.pop('reuse_weights', None)
 
-    # Renaming dims to original names for the subset regridding
-    if locstream_in:
-        ds_in = ds_in.rename({'x_in': in_dims[0]})
-    else:
-        ds_in = ds_in.rename({'y_in': in_dims[0], 'x_in': in_dims[1]})
-
-    if locstream_out:
-        ds_out = ds_out.rename({'x_out': out_dims[1]})
-    else:
-        ds_out = ds_out.rename({'y_out': out_dims[0], 'x_out': out_dims[1]})
+    ds_in = _rename_dataset(ds_in, locstream_in, in_dims, '_in')
+    ds_out = _rename_dataset(ds_out, locstream_out, out_dims, '_out')
 
     regridder = Regridder(
         ds_in, ds_out, method, locstream_in, locstream_out, periodic, parallel=False, **kwargs
@@ -1153,32 +1152,13 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs):
         ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds'
         ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims)
         # rename dims to avoid map_blocks confusing ds_in and ds_out dims.
-        if self.sequence_in:
-            ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'})
-        else:
-            ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'})
-
-        if self.sequence_out:
-            ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'})
-        else:
-            ds_out = ds_out.rename(
-                {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'}
-            )
+        ds_in = _unname_dataset(ds_in, self.sequence_in, self.in_horiz_dims, '_in')
+        ds_out = _unname_dataset(ds_out, self.sequence_out, self.out_horiz_dims, '_out')
 
         out_chunks = {k: ds_out.chunks.get(k) for k in ['y_out', 'x_out']}
         in_chunks = {k: ds_in.chunks.get(k) for k in ['y_in', 'x_in']}
         chunks = out_chunks | in_chunks
 
-        # Rename coords to avoid issues in xr.map_blocks
-        # If coords and dims are the same, renaming has already been done.
-        ds_out = ds_out.rename(
-            {
-                coord: coord + '_out'
-                for coord in self.out_coords.coords.keys()
-                if coord not in self.out_horiz_dims
-            }
-        )
-
         weights_dims = ('y_out', 'x_out', 'y_in', 'x_in')
         templ = sps.zeros((self.shape_out + self.shape_in))
         w_templ = xr.DataArray(templ, dims=weights_dims).chunk(
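
The renaming matters because ``xr.map_blocks`` receives a single object carrying both grids: if the source and target share a dimension name (say ``lat``) with different sizes, they cannot be combined at all. A toy illustration of the conflict and of the suffixing fix, independent of xESMF internals (all names here are made up for the example):

import numpy as np
import xarray as xr

ds_in = xr.Dataset(coords={'lat': np.arange(10.0), 'lon': np.arange(20.0)})
ds_out = xr.Dataset(coords={'lat': np.arange(40.0), 'lon': np.arange(80.0)})

# Same dim names, different sizes: the two grids cannot be aligned exactly.
try:
    xr.merge([ds_in, ds_out], join='exact')
except ValueError as err:
    print('conflict:', err)

# After suffixing, both grids coexist with unambiguous dimensions.
merged = xr.merge(
    [
        ds_in.rename({'lat': 'y_in', 'lon': 'x_in'}),
        ds_out.rename({'lat': 'y_out', 'lon': 'x_out'}),
    ]
)
print(sorted(merged.dims))  # ['x_in', 'x_out', 'y_in', 'y_out']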

xesmf/smm.py

Lines changed: 26 additions & 12 deletions

@@ -260,7 +260,8 @@ def add_nans_to_weights(weights):
 
     By default, empty rows in the weights sparse matrix are interpreted as zeroes. This can become problematic
     when the field being interpolated has legitimate null values. This function inserts nan values in each row to
-    make sure empty weights are propagated as nans instead of zeros.
+    make sure empty weights are propagated as nans instead of zeros. It also removes unnecessary entries, ones
+    where the data is the same as the fill value (0).
 
     Parameters
     ----------
@@ -272,17 +273,30 @@ def add_nans_to_weights(weights):
    DataArray backed by a sparse.COO array
        Sparse weights matrix.
    """
-
-    # Taken from @trondkr and adapted by @raphaeldussin to use `lil`.
-    # lil matrix is better than CSR when changing sparsity
-    m = weights.data.to_scipy_sparse().tolil()
-    # replace empty rows by one nan value at element 0 (arbitrary)
-    # so that remapped element become nan instead of zero
-    for krow in range(len(m.rows)):
-        m.rows[krow] = [0] if m.rows[krow] == [] else m.rows[krow]
-        m.data[krow] = [np.nan] if m.data[krow] == [] else m.data[krow]
-    # update regridder weights (in COO)
-    weights = weights.copy(data=sps.COO.from_scipy_sparse(m))
+    # Taken from @trondkr and adapted by @raphaeldussin to use `lil`, translated to COO by @aulemahal
+    coo = weights.data
+    coords = coo.coords
+    data = coo.data
+    # Remove unnecessary entries (the roundtrip through scipy's lil did that implicitly)
+    coords = coords[:, data != coo.fill_value]
+    data = data[data != coo.fill_value]
+
+    # Replace rows with no weights with a NaN at element 0, so that remapped elements are NaNs instead of zeros.
+    # Find rows with no entry in the weights, the unmapped ones
+    unmapped_rows = set(np.arange(coo.shape[0])) - set(coords[0])
+    # Generate one coord per unmapped row
+    new_coords = np.array([list(unmapped_rows), [0] * len(unmapped_rows)], dtype=coords.dtype)
+    # Assign a NaN to the new coord so the scalar product of that row gives a NaN
+    new_data = np.full((len(unmapped_rows),), np.nan)
+
+    # Recreate the new COO weights matrix
+    new = sps.COO(
+        np.hstack((coords, new_coords)),
+        np.hstack((data, new_data)),
+        coo.shape,
+        fill_value=coo.fill_value,
+    )
+    weights = weights.copy(data=new)
     return weights
 
 
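The new implementation only needs ``numpy`` and the ``sparse`` package. Below is a self-contained sketch of the same technique on a toy 4x3 weights matrix (not the library function itself), where destination row 2 receives no weights:

import numpy as np
import sparse as sps

# Toy weights: rows are destination cells, columns source cells; row 2 is empty.
coords = np.array([[0, 0, 1, 3], [0, 1, 2, 0]])  # (row, col) indices of the nonzeros
data = np.array([0.5, 0.5, 1.0, 1.0])
coo = sps.COO(coords, data, shape=(4, 3))

# Rows absent from coords[0] got no weights: they would remap to 0, not NaN.
unmapped_rows = sorted(set(np.arange(coo.shape[0])) - set(coo.coords[0]))
new_coords = np.array([unmapped_rows, [0] * len(unmapped_rows)], dtype=coords.dtype)
new_data = np.full(len(unmapped_rows), np.nan)

# Concatenate the NaN entries onto the existing ones and rebuild the matrix.
patched = sps.COO(
    np.hstack((coo.coords, new_coords)),
    np.hstack((coo.data, new_data)),
    coo.shape,
)
print(patched.todense())  # row 2 now holds a NaN at column 0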

xesmf/util.py

Lines changed: 43 additions & 0 deletions

@@ -415,3 +415,46 @@ def _get_edge_indices_2d(nlons, nlats):
     edge_mask[:, :1] = True
     edge_mask[:, -1:] = True
     return np.where(edge_mask.ravel())[0]
+
+
+def _unname_dataset(ds, sequence, dims, suffix):
+    """Rename everything in a dataset so that it can be aligned without modification with another."""
+    if sequence:
+        dim = list(set(dims) - {'dummy'})[0]
+        ds = ds.rename({dim: f'x{suffix}'})
+    else:
+        ds = ds.rename({dims[0]: f'y{suffix}', dims[1]: f'x{suffix}'})
+    if ds[f'x{suffix}'].attrs.get('bounds'):
+        ds = ds.rename({ds[f'x{suffix}'].attrs['bounds']: f'x{suffix}_bounds'})
+        ds[f'x{suffix}'].attrs['bounds'] = f'x{suffix}_bounds'
+    if not sequence and ds[f'y{suffix}'].attrs.get('bounds'):
+        ds = ds.rename({ds[f'y{suffix}'].attrs['bounds']: f'y{suffix}_bounds'})
+        ds[f'y{suffix}'].attrs['bounds'] = f'y{suffix}_bounds'
+
+    # If coords and dims are the same, renaming has already been done.
+    ds = ds.rename(
+        {
+            coord: coord + suffix
+            for coord in ds.coords.keys()
+            if coord not in (f'y{suffix}', f'x{suffix}')
+        }
+    )
+    return ds
+
+
+def _rename_dataset(ds, sequence, dims, suffix):
+    """Restore coordinate names from an "unnamed" dataset."""
+    ds = ds.rename(
+        {
+            coord: coord.removesuffix(suffix)
+            for coord in ds.coords.keys()
+            if coord not in dims
+            and coord.endswith(suffix)
+            and coord not in (f'y{suffix}', f'x{suffix}')
+        }
+    )
+    if sequence:
+        ds = ds.rename({f'x{suffix}': dims[0]})
+    else:
+        ds = ds.rename({f'y{suffix}': dims[0], f'x{suffix}': dims[1]})
+    return ds
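
A round-trip sketch of these two helpers on a toy dataset. They are private, underscore-prefixed functions, so this mirrors how ``frontend.py`` uses them rather than public API; ``sequence=False`` is the ordinary 2D-grid case, and the coordinate names are made up for the example.

import numpy as np
import xarray as xr
from xesmf.util import _rename_dataset, _unname_dataset

ds = xr.Dataset(coords={'lat': np.arange(4.0), 'lon': np.arange(8.0)})

# "Unname": move to neutral y_in/x_in dims so in/out grids cannot collide.
unnamed = _unname_dataset(ds, sequence=False, dims=('lat', 'lon'), suffix='_in')
print(sorted(unnamed.dims))  # ['x_in', 'y_in']

# "Rename": restore the original dimension and coordinate names.
restored = _rename_dataset(unnamed, sequence=False, dims=('lat', 'lon'), suffix='_in')
print(sorted(restored.dims))  # ['lat', 'lon']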
