Skip to content

Commit 6c89e44

Browse files
Merge pull request #544 from matthewhoffman/landice/esmf_interp_fix
Fix bugs in ESMF interpolation method This merge fixes a number of issues with the 'esmf' method for interpolation in the MALI interpolation script: * indexing was off by one, which made interpolation with unstructured source meshes garbage. For structured source meshes, it made interpolation shifted by one grid cell. * support MPAS source fields with a vertical dimension when using the 'esmf' method * refactor to use sparse matrix multiply, which speeds up interpolation a few hundred times * add destination mesh area normalization support. This is necessary when using the ESMF 'conserve' method for destination cells that are only partly overlapped by the source cells. Two issues remain for interpolating between two MPAS meshes with the 'esmf' method: 1. If the destination mesh is larger than the source mesh, those locations are filled with zeros. The ESMF 'conserve' method does not support extrapolation, and there is no obvious solution to this issue and it would need to be handled manually on a case by case basis. 2. Some fields are only defined on subdomains (e.g. temperature is only defined where ice thickness is nonzero) and the script currently has no mechanism for masking them. This results in garbage values getting interpolated in, e.g. temperature values near the margin will have values around 100 K because of interpolating realistic values around 250K with garbage values of 0K. This issue applies to the barycentric method as well. However, this situation can be worked around by performing extrapolation of the temperature field on the source mesh before doing interpolation between meshes. This PR also includes a refactoring of create_SCRIP_file_from_planar_rectangular_grid.py that speeds it up by orders of magnitude for large meshes, as well as a minor update to define_cullMask.py
2 parents bcedec2 + b0e6147 commit 6c89e44

File tree

3 files changed

+40
-23
lines changed

3 files changed

+40
-23
lines changed

landice/mesh_tools_li/define_cullMask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
if keepCellMask[n] == 1:
102102
keepCellMaskNew[iCell] = 1
103103
keepCellMask = np.copy(keepCellMaskNew) # after we've looped over all cells assign the new mask to the variable we need (either for another loop around the domain or to write out)
104-
print(' Num of cells to keep: {}'.format(sum(keepCellMask)))
104+
print(f'Num of cells to keep: {keepCellMask.sum()}')
105105

106106
# Now convert the keepCellMask to the cullMask
107107
cullCell[:] = np.absolute(keepCellMask[:]-1) # Flip the mask for which ones to cull
@@ -148,7 +148,7 @@
148148
ind = np.nonzero(((xCell-xCell[iCell])**2 + (yCell-yCell[iCell])**2)**0.5 < dist)[0]
149149
keepCellMask[ind] = 1
150150

151-
print(' Num of cells to keep:'.format(sum(keepCellMask)))
151+
print(f'Num of cells to keep: {keepCellMask.sum()}')
152152

153153
# Now convert the keepCellMask to the cullMask
154154
cullCell[:] = np.absolute(keepCellMask[:]-1) # Flip the mask for which ones to cull

landice/mesh_tools_li/interpolate_to_mpasli_grid.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import math
3333
from collections import OrderedDict
3434
import scipy.spatial
35+
import scipy.sparse
3536
import time
3637
from datetime import datetime
3738

@@ -64,9 +65,15 @@
6465
S = wfile.variables['S'][:]
6566
col = wfile.variables['col'][:]
6667
row = wfile.variables['row'][:]
68+
n_a = len(wfile.dimensions['n_a'])
69+
n_b = len(wfile.dimensions['n_b'])
70+
dst_frac = wfile.variables['frac_b'][:]
6771
wfile.close()
6872
#----------------------------
6973

74+
# convert to SciPy Compressed Sparse Row (CSR) matrix format
75+
weights_csr = scipy.sparse.coo_array((S, (row - 1, col - 1)), shape=(n_b, n_a)).tocsr()
76+
7077
print('') # make a space in stdout before further output
7178

7279

@@ -78,15 +85,20 @@
7885
#----------------------------
7986

8087
def ESMF_interp(sourceField):
81-
# Interpolates from the sourceField to the destinationField using ESMF weights
88+
# Interpolates from the sourceField to the destinationField using ESMF weights
89+
destinationField = np.zeros(xCell.shape) # fields on cells only
8290
try:
83-
# Initialize new field to 0 - required
84-
destinationField = np.zeros(xCell.shape) # fields on cells only
85-
sourceFieldFlat = sourceField.flatten() # Flatten source field
86-
for i in range(len(row)):
87-
destinationField[row[i]-1] = destinationField[row[i]-1] + S[i] * sourceFieldFlat[col[i]]
91+
# Convert the source field into the SciPy Compressed Sparse Row matrix format
92+
# This needs some reshaping to get the matching dimensions
93+
source_csr = scipy.sparse.csr_matrix(sourceField.flatten()[:, np.newaxis])
94+
# Use SciPy CSR dot product - much faster than iterating over elements of the full matrix
95+
destinationField = weights_csr.dot(source_csr).toarray().squeeze()
96+
# For conserve remapping, need to normalize by destination area fraction
97+
# It should be safe to do this for other methods
98+
ind = np.where(dst_frac > 0.0)[0]
99+
destinationField[ind] /= dst_frac[ind]
88100
except:
89-
'error in ESMF_interp'
101+
print('error in ESMF_interp')
90102
return destinationField
91103

92104
#----------------------------
@@ -328,7 +340,7 @@ def interpolate_field_with_layers(MPASfieldName):
328340
if filetype=='cism':
329341
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[z,:,:].min(), InputField[z,:,:].max()))
330342
elif filetype=='mpas':
331-
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[:,z].min(), InputField[z,:].max()))
343+
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[:,z].min(), InputField[:,z].max()))
332344
# Call the appropriate routine for actually doing the interpolation
333345
if args.interpType == 'b':
334346
print(" ...Layer {}, Interpolating this layer to MPAS grid using built-in bilinear method...".format(z))
@@ -349,7 +361,10 @@ def interpolate_field_with_layers(MPASfieldName):
349361
mpas_grid_input_layers[z,:] = InputField[:,z].flatten()[nn_idx_cell] # 2d cism fields need to be flattened. (Note the indices were flattened during init, so this just matches that operation for the field data itself.) 1d mpas fields do not, but the operation won't do anything because they are already flat.
350362
elif args.interpType == 'e':
351363
print(" ...Layer{}, Interpolating this layer to MPAS grid using ESMF-weights method...".format(z))
352-
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[z,:,:])
364+
if filetype=='cism':
365+
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[z,:,:])
366+
elif filetype=='mpas':
367+
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[:,z])
353368
else:
354369
sys.exit('ERROR: Unknown interpolation method specified')
355370
print(' interpolated MPAS {}, layer {} min/max {} {}: '.format(MPASfieldName, z, mpas_grid_input_layers[z,:].min(), mpas_grid_input_layers[z,:].max()))

mesh_tools/create_SCRIP_files/create_SCRIP_file_from_planar_rectangular_grid.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,19 @@
125125
print ('Filling in corners of each cell.')
126126
grid_corner_lon_local = np.zeros( (nx * ny, 4) ) # It is WAYYY faster to fill in the array entry-by-entry in memory than to disk.
127127
grid_corner_lat_local = np.zeros( (nx * ny, 4) )
128-
for j in range(ny):
129-
for i in range(nx):
130-
iCell = j*nx + i
131-
132-
grid_corner_lon_local[iCell, 0] = stag_lon[j, i]
133-
grid_corner_lon_local[iCell, 1] = stag_lon[j, i+1]
134-
grid_corner_lon_local[iCell, 2] = stag_lon[j+1, i+1]
135-
grid_corner_lon_local[iCell, 3] = stag_lon[j+1, i]
136-
grid_corner_lat_local[iCell, 0] = stag_lat[j, i]
137-
grid_corner_lat_local[iCell, 1] = stag_lat[j, i+1]
138-
grid_corner_lat_local[iCell, 2] = stag_lat[j+1, i+1]
139-
grid_corner_lat_local[iCell, 3] = stag_lat[j+1, i]
128+
129+
jj = np.arange(ny)
130+
ii = np.arange(nx)
131+
i_ind, j_ind = np.meshgrid(ii, jj)
132+
cell_ind = j_ind * nx + i_ind
133+
grid_corner_lon_local[cell_ind, 0] = stag_lon[j_ind, i_ind]
134+
grid_corner_lon_local[cell_ind, 1] = stag_lon[j_ind, i_ind + 1]
135+
grid_corner_lon_local[cell_ind, 2] = stag_lon[j_ind + 1, i_ind + 1]
136+
grid_corner_lon_local[cell_ind, 3] = stag_lon[j_ind + 1, i_ind]
137+
grid_corner_lat_local[cell_ind, 0] = stag_lat[j_ind, i_ind]
138+
grid_corner_lat_local[cell_ind, 1] = stag_lat[j_ind, i_ind + 1]
139+
grid_corner_lat_local[cell_ind, 2] = stag_lat[j_ind + 1, i_ind + 1]
140+
grid_corner_lat_local[cell_ind, 3] = stag_lat[j_ind + 1, i_ind]
140141

141142
grid_corner_lon[:] = grid_corner_lon_local[:]
142143
grid_corner_lat[:] = grid_corner_lat_local[:]
@@ -171,3 +172,4 @@
171172

172173
fin.close()
173174
fout.close()
175+
print('scrip file generation complete')

0 commit comments

Comments
 (0)