Skip to content

Commit 4f01705

Browse files
committed
Add single-rank test for lsq_pseudoinv factory
1 parent 4d2c0f5 commit 4f01705

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

model/common/tests/common/interpolation/mpi_tests/test_parallel_interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def test_distributed_interpolation_lsq_pseudoinv(
232232
parallel_helpers.log_process_properties(processor_props)
233233
parallel_helpers.log_local_field_size(decomposition_info)
234234
factory = interpolation_factory_from_savepoint
235-
field_ref_1 = interpolation_savepoint.__getattribute__("lsq_pseudoinv_1")().asnumpy()
236-
field_ref_2 = interpolation_savepoint.__getattribute__("lsq_pseudoinv_2")().asnumpy()
235+
field_ref_1 = interpolation_savepoint.lsq_pseudoinv_1().asnumpy()
236+
field_ref_2 = interpolation_savepoint.lsq_pseudoinv_2().asnumpy()
237237
field_1 = factory.get(attrs.LSQ_PSEUDOINV)[:, 0, :]
238238
field_2 = factory.get(attrs.LSQ_PSEUDOINV)[:, 1, :]
239239
assert test_utils.dallclose(field_1, field_ref_1, atol=1e-15) # type: ignore[arg-type] # mypy does not recognize sliced array as still an array

model/common/tests/common/interpolation/unit_tests/test_interpolation_factory.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,20 @@ def test_rbf_interpolation_coeffs_vertex(
374374
field_v2[horizontal_start:],
375375
atol=RBF_TOLERANCES[dims.VertexDim][experiment.name],
376376
)
377+
378+
379+
@pytest.mark.level("integration")
380+
@pytest.mark.datatest
381+
def test_lsq_pseudoinv(
382+
interpolation_savepoint: serialbox.InterpolationSavepoint,
383+
experiment: definitions.Experiment,
384+
backend: gtx_typing.Backend | None,
385+
) -> None:
386+
field_ref_1 = interpolation_savepoint.lsq_pseudoinv_1()
387+
field_ref_2 = interpolation_savepoint.lsq_pseudoinv_2()
388+
factory = _get_interpolation_factory(backend, experiment)
389+
field = factory.get(attrs.LSQ_PSEUDOINV).asnumpy()
390+
field_1 = field[:, 0, :]
391+
field_2 = field[:, 1, :]
392+
assert test_helpers.dallclose(field_ref_1.asnumpy(), field_1, atol=1e-15)
393+
assert test_helpers.dallclose(field_ref_2.asnumpy(), field_2, atol=1e-15)

0 commit comments

Comments
 (0)