Skip to content

Commit b3d024f

Browse files
Add single-rank lsq pseudoinv factory test (#1099)
Related to #1065 (comment). This uses @nfarabullini's fixes for doing the LSQ_PSEUDOINV computation with a GPU backend from #1012 to have them merged separately. This also adds LSQ_PSEUDOINV to the parallel test in test_parallel_grid_manager.py. This required changing the computation to avoid doing the SVD on halo points. Custom dimensions are added for LSQ_PSEUDOINV to correctly declare the dimensions in the factory. Previously the dimension was equivalne to a scalar, which lead to the factory returning the plain numpy array. With this change the mypy type ignores can be removed from test_parallel_interpolation.py. --------- Co-authored-by: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com>
1 parent 351b63a commit b3d024f

File tree

9 files changed

+57
-23
lines changed

9 files changed

+57
-23
lines changed

model/common/src/icon4py/model/common/dimension.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
CellDim = gtx.Dimension("Cell")
1717
VertexDim = gtx.Dimension("Vertex")
1818
MAIN_HORIZONTAL_DIMENSIONS = {"CellDim": CellDim, "EdgeDim": EdgeDim, "VertexDim": VertexDim}
19+
LsqCDim = gtx.Dimension("LsqC", gtx.DimensionKind.LOCAL)
20+
LsqUnkDim = gtx.Dimension("LsqUnk", gtx.DimensionKind.LOCAL)
1921
E2CDim = gtx.Dimension("E2C", gtx.DimensionKind.LOCAL)
2022
E2VDim = gtx.Dimension("E2V", gtx.DimensionKind.LOCAL)
2123
C2EDim = gtx.Dimension("C2E", gtx.DimensionKind.LOCAL)

model/common/src/icon4py/model/common/interpolation/interpolation_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _register_computed_fields(self) -> None:
243243
array_ns=self._xp,
244244
),
245245
fields=(attrs.LSQ_PSEUDOINV,),
246-
domain=(),
246+
domain=(dims.CellDim, dims.LsqUnkDim, dims.LsqCDim),
247247
deps={
248248
"cell_center_x": geometry_attrs.CELL_CENTER_X,
249249
"cell_center_y": geometry_attrs.CELL_CENTER_Y,

model/common/src/icon4py/model/common/interpolation/interpolation_fields.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,8 @@ def compute_lsq_pseudoinv(
11671167
for jjb in range(lsq_dim_c):
11681168
for jjk in range(lsq_dim_unk):
11691169
for jc in range(start_idx, min_rlcell_int):
1170-
u, s, v_t = array_ns.linalg.svd(z_lsq_mat_c[jc, :, :])
11711170
if cell_owner_mask[jc]:
1171+
u, s, v_t = array_ns.linalg.svd(z_lsq_mat_c[jc, :, :])
11721172
lsq_pseudoinv[jc, :lsq_dim_unk, jjb] = (
11731173
lsq_pseudoinv[jc, :lsq_dim_unk, jjb]
11741174
+ v_t[jjk, :lsq_dim_unk] / s[jjk] * u[jjb, jjk] * lsq_weights_c[jc, jjb]
@@ -1237,7 +1237,11 @@ def compute_lsq_coeffs(
12371237
for js in range(lsq_dim_stencil):
12381238
z_dist_g[:, js, :] = array_ns.asarray(
12391239
gnomonic_proj(
1240-
cell_lon, cell_lat, cell_lon[c2e2c[:, js]], cell_lat[c2e2c[:, js]]
1240+
cell_lon,
1241+
cell_lat,
1242+
cell_lon[c2e2c[:, js]],
1243+
cell_lat[c2e2c[:, js]],
1244+
array_ns,
12411245
)
12421246
).T
12431247

@@ -1252,15 +1256,17 @@ def compute_lsq_coeffs(
12521256
ilc_s = c2e2c[jc, :lsq_dim_stencil]
12531257
cc_cell = array_ns.zeros((lsq_dim_stencil, 2))
12541258

1255-
cc_cv = (cell_center_x[jc], cell_center_y[jc])
1259+
cc_cv = array_ns.asarray((cell_center_x[jc], cell_center_y[jc]))
12561260
for js in range(lsq_dim_stencil):
1257-
cc_cell[js, :] = diff_on_edges_torus_numpy(
1258-
cell_center_x[jc],
1259-
cell_center_y[jc],
1260-
cell_center_x[ilc_s][js],
1261-
cell_center_y[ilc_s][js],
1262-
domain_length,
1263-
domain_height,
1261+
cc_cell[js, :] = array_ns.asarray(
1262+
diff_on_edges_torus_numpy(
1263+
cell_center_x[jc],
1264+
cell_center_y[jc],
1265+
cell_center_x[ilc_s][js],
1266+
cell_center_y[ilc_s][js],
1267+
domain_length,
1268+
domain_height,
1269+
)
12641270
)
12651271
z_dist_g[jc, :, :] = cc_cell - cc_cv
12661272

model/common/src/icon4py/model/common/math/projection.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# Please, refer to the LICENSE file in the root directory.
77
# SPDX-License-Identifier: BSD-3-Clause
8-
8+
from types import ModuleType
99

1010
import numpy as np
1111

@@ -17,6 +17,7 @@ def gnomonic_proj(
1717
lat_c: data_alloc.NDArray,
1818
lon: data_alloc.NDArray,
1919
lat: data_alloc.NDArray,
20+
array_ns: ModuleType = np,
2021
) -> tuple[data_alloc.NDArray, data_alloc.NDArray]:
2122
"""
2223
Compute gnomonic projection.
@@ -38,11 +39,16 @@ def gnomonic_proj(
3839
TODO:
3940
replace this with a suitable library call
4041
"""
41-
cosc = np.sin(lat_c) * np.sin(lat) + np.cos(lat_c) * np.cos(lat) * np.cos(lon - lon_c)
42+
cosc = array_ns.sin(lat_c) * array_ns.sin(lat) + array_ns.cos(lat_c) * array_ns.cos(
43+
lat
44+
) * array_ns.cos(lon - lon_c)
4245
zk = 1.0 / cosc
4346

44-
x = zk * np.cos(lat) * np.sin(lon - lon_c)
45-
y = zk * (np.cos(lat_c) * np.sin(lat) - np.sin(lat_c) * np.cos(lat) * np.cos(lon - lon_c))
47+
x = zk * array_ns.cos(lat) * array_ns.sin(lon - lon_c)
48+
y = zk * (
49+
array_ns.cos(lat_c) * array_ns.sin(lat)
50+
- array_ns.sin(lat_c) * array_ns.cos(lat) * array_ns.cos(lon - lon_c)
51+
)
4652

4753
return x, y
4854

model/common/tests/common/grid/mpi_tests/test_parallel_grid_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def _get_neighbor_tables(grid: base.Grid) -> dict:
8888
def gather_field(field: np.ndarray, props: decomp_defs.ProcessProperties) -> tuple:
8989
constant_dims = tuple(field.shape[1:])
9090
_log.info(f"gather_field on rank={props.rank} - gathering field of local shape {field.shape}")
91+
# Because of sparse indexing the field may have a non-contigous layout,
92+
# which Gatherv doesn't support. Make sure the field is contiguous.
93+
field = np.ascontiguousarray(field)
9194
constant_length = functools.reduce(operator.mul, constant_dims, 1)
9295
local_sizes = np.array(props.comm.gather(field.size, root=0))
9396
if props.rank == 0:
@@ -337,6 +340,7 @@ def test_geometry_fields_compare_single_multi_rank(
337340
interpolation_attributes.GEOFAC_GRG_Y,
338341
interpolation_attributes.GEOFAC_N2S,
339342
interpolation_attributes.GEOFAC_ROT,
343+
interpolation_attributes.LSQ_PSEUDOINV,
340344
interpolation_attributes.NUDGECOEFFS_E,
341345
interpolation_attributes.POS_ON_TPLANE_E_X,
342346
interpolation_attributes.POS_ON_TPLANE_E_Y,

model/common/tests/common/grid/unit_tests/test_icon.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ def test_when_replace_skip_values_then_only_pentagon_points_remain(
214214
) -> None:
215215
if dim == dims.V2E2VDim:
216216
pytest.skip("V2E2VDim is not supported in the current grid configuration.")
217+
if dim in (dims.LsqCDim, dims.LsqUnkDim):
218+
pytest.skip("LsqCDim and LsqUnkDim are not offset dimensions.")
217219
grid = utils.run_grid_manager(grid_descriptor, keep_skip_values=False, backend=backend).grid
218220
connectivity = grid.get_connectivity(dim.value)
219221
if dim in icon.CONNECTIVITIES_ON_PENTAGONS and not grid.limited_area:

model/common/tests/common/interpolation/mpi_tests/test_parallel_interpolation.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,8 @@ def test_distributed_interpolation_lsq_pseudoinv(
235235
parallel_helpers.log_process_properties(processor_props)
236236
parallel_helpers.log_local_field_size(decomposition_info)
237237
factory = interpolation_factory_from_savepoint
238-
field_ref_1 = interpolation_savepoint.__getattribute__("lsq_pseudoinv_1")().asnumpy()
239-
field_ref_2 = interpolation_savepoint.__getattribute__("lsq_pseudoinv_2")().asnumpy()
240-
field_1 = factory.get(attrs.LSQ_PSEUDOINV)[:, 0, :]
241-
field_2 = factory.get(attrs.LSQ_PSEUDOINV)[:, 1, :]
242-
assert test_utils.dallclose(field_1, field_ref_1, atol=1e-15) # type: ignore[arg-type] # mypy does not recognize sliced array as still an array
243-
assert test_utils.dallclose(field_2, field_ref_2, atol=1e-15) # type: ignore[arg-type] # mypy does not recognize sliced array as still an array
238+
field_ref_1 = interpolation_savepoint.lsq_pseudoinv_1().asnumpy()
239+
field_ref_2 = interpolation_savepoint.lsq_pseudoinv_2().asnumpy()
240+
field = factory.get(attrs.LSQ_PSEUDOINV).asnumpy()
241+
assert test_utils.dallclose(field[:, 0, :], field_ref_1, atol=1e-15)
242+
assert test_utils.dallclose(field[:, 1, :], field_ref_2, atol=1e-15)

model/common/tests/common/interpolation/unit_tests/test_interpolation_factory.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,18 @@ def test_rbf_interpolation_coeffs_vertex(
374374
field_v2[horizontal_start:],
375375
atol=RBF_TOLERANCES[dims.VertexDim][experiment.name],
376376
)
377+
378+
379+
@pytest.mark.level("integration")
380+
@pytest.mark.datatest
381+
def test_lsq_pseudoinv(
382+
interpolation_savepoint: serialbox.InterpolationSavepoint,
383+
experiment: definitions.Experiment,
384+
backend: gtx_typing.Backend | None,
385+
) -> None:
386+
field_ref_1 = interpolation_savepoint.lsq_pseudoinv_1().asnumpy()
387+
field_ref_2 = interpolation_savepoint.lsq_pseudoinv_2().asnumpy()
388+
factory = _get_interpolation_factory(backend, experiment)
389+
field = factory.get(attrs.LSQ_PSEUDOINV).asnumpy()
390+
assert test_helpers.dallclose(field_ref_1, field[:, 0, :], atol=1e-15)
391+
assert test_helpers.dallclose(field_ref_2, field[:, 1, :], atol=1e-15)

model/testing/src/icon4py/model/testing/serialbox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,10 +703,10 @@ def rbf_vec_idx_v(self):
703703
return self._get_field("rbf_vec_idx_v", dims.VertexDim, dims.V2EDim)
704704

705705
def lsq_pseudoinv_1(self):
706-
return self._get_field("lsq_pseudoinv_1", dims.CellDim, dims.C2E2CDim)
706+
return self._get_field("lsq_pseudoinv_1", dims.CellDim, dims.LsqCDim)
707707

708708
def lsq_pseudoinv_2(self):
709-
return self._get_field("lsq_pseudoinv_2", dims.CellDim, dims.C2E2CDim)
709+
return self._get_field("lsq_pseudoinv_2", dims.CellDim, dims.LsqCDim)
710710

711711

712712
class MetricSavepoint(IconSavepoint):

0 commit comments

Comments
 (0)