Skip to content

Commit a967314

Browse files
philip-paul-muellerPhilip MuellerhavogtCopilotmsimberg
authored
Scheduled Halo Exchange (#980)
This PR introduces the [scheduled exchange feature](ghex-org/GHEX#190) from GHEX into ICON4Py. These exchange allows to call the exchange function before all work has been completed, i.e. the exchange will wait until the previous work is done. A similar feature is the "scheduled wait", that allows to initiate the receive without the need to wait on its completion. In addition to this the function also renamed the functions related to halo exchange: - `exchange()` was renamed to `start()`. - `wait()` was renamed to `finish()` (that might now return before the transfer has fully concluded). - `exchange_and_wait()` was renamed to `exchange()`. All of these functions now accepts the an argument called `stream`, which defaults to `DEFAULT_STREAM`. It is indicate how synchronization with the stream should be performed. In case of `start()` it means that the actual exchange should not start until all work previously submitted to `stream` has finished. For `finish()` it means that further work, submitted to `stream`, should not start until the exchange has ended. For `finish()` it is also possible to specify `BLOCK`, which means that `finish()` waits until the transfer has fully finished. The orchestrator was not updated, but the change were made in such a way that it continues to work in diffusion, although using the original, blocking behaviour. --------- Co-authored-by: Philip Mueller <philip.paul.mueller@bluemain.ch> Co-authored-by: Hannes Vogt <vogt@hey.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Mikael Simberg <mikael.simberg@iki.fi> Co-authored-by: Hannes Vogt <hannes@havogt.de>
1 parent d5778fb commit a967314

File tree

15 files changed

+2459
-2148
lines changed

15 files changed

+2459
-2148
lines changed

model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def run(
207207
log.debug("advection run - start")
208208

209209
log.debug("communication of prep_adv cell field: mass_flx_ic - start")
210-
self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic)
210+
self._exchange.exchange(
211+
dims.CellDim, prep_adv.mass_flx_ic, stream=decomposition.DEFAULT_STREAM
212+
)
211213
log.debug("communication of prep_adv cell field: mass_flx_ic - end")
212214

213215
log.debug("running stencil copy_cell_kdim_field - start")
@@ -313,7 +315,11 @@ def run(
313315
log.debug("advection run - start")
314316

315317
log.debug("communication of prep_adv cell field: mass_flx_ic - start")
316-
self._exchange.exchange_and_wait(dims.CellDim, prep_adv.mass_flx_ic)
318+
self._exchange.exchange(
319+
dims.CellDim,
320+
prep_adv.mass_flx_ic,
321+
stream=decomposition.DEFAULT_STREAM,
322+
)
317323
log.debug("communication of prep_adv cell field: mass_flx_ic - end")
318324

319325
# reintegrate density for conservation of mass
@@ -396,7 +402,11 @@ def run(
396402

397403
# exchange updated tracer values, originally happens only if iforcing /= inwp
398404
log.debug("communication of advection cell field: p_tracer_new - start")
399-
self._exchange.exchange_and_wait(dims.CellDim, p_tracer_new)
405+
self._exchange.exchange(
406+
dims.CellDim,
407+
p_tracer_new,
408+
stream=decomposition.DEFAULT_STREAM,
409+
)
400410
log.debug("communication of advection cell field: p_tracer_new - end")
401411

402412
# finalize step

model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_horizontal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def apply_flux_limiter(
174174
)
175175

176176
log.debug("communication of advection cell field: r_m - start")
177-
self._exchange.exchange_and_wait(dims.CellDim, self._r_m)
177+
self._exchange.exchange(dims.CellDim, self._r_m, stream=decomposition.DEFAULT_STREAM)
178178
log.debug("communication of advection cell field: r_m - end")
179179

180180
# limit outward fluxes

model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
self._cell_params = cell_params
418418

419419
self.halo_exchange_wait = decomposition.create_halo_exchange_wait(
420-
self._exchange
420+
self._exchange,
421421
) # wait on a communication handle
422422
self.rd_o_cvd: float = constants.GAS_CONSTANT_DRY_AIR / (
423423
constants.CPD - constants.GAS_CONSTANT_DRY_AIR
@@ -761,11 +761,12 @@ def _sync_cell_fields(self, prognostic_state):
761761
IF ( linit .OR. (iforcing /= inwp .AND. iforcing /= iaes) ) THEN
762762
"""
763763
log.debug("communication of prognostic cell fields: theta, w, exner - start")
764-
self._exchange.exchange_and_wait(
764+
self._exchange.exchange(
765765
dims.CellDim,
766766
prognostic_state.w,
767767
prognostic_state.theta_v,
768768
prognostic_state.exner,
769+
stream=decomposition.DEFAULT_STREAM,
769770
)
770771
log.debug("communication of prognostic cell fields: theta, w, exner - done")
771772

@@ -802,12 +803,17 @@ def _do_diffusion_step(
802803
log.debug("rbf interpolation 1: end")
803804

804805
# 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert
806+
# TODO(phimuell, muellch): Is asynchronous mode okay here.
807+
# NOTE: We do not specify a stream here but rely on the default argument.
808+
# We do this to ensure that the orchestrator works, but it is not aware
809+
# of the streams.
805810
log.debug("communication rbf extrapolation of vn - start")
806811
self._exchange(
807812
self.u_vert,
808813
self.v_vert,
809814
dim=dims.VertexDim,
810-
wait=True,
815+
full_exchange=True,
816+
# stream=decomposition.DEFAULT_STREAM, # noqa: ERA001 # See NOTE above.
811817
)
812818
log.debug("communication rbf extrapolation of vn - end")
813819

@@ -850,12 +856,14 @@ def _do_diffusion_step(
850856
log.debug("2nd rbf interpolation: end")
851857

852858
# 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields)
859+
# TODO(phimuell, muellch): Is asynchronous mode okay here.
853860
log.debug("communication rbf extrapolation of z_nable2_e - start")
854861
self._exchange(
855862
self.u_vert,
856863
self.v_vert,
857864
dim=dims.VertexDim,
858-
wait=True,
865+
full_exchange=True,
866+
# stream=decomposition.DEFAULT_STREAM, # noqa: ERA001 # See NOTE above.
859867
)
860868
log.debug("communication rbf extrapolation of z_nable2_e - end")
861869

@@ -871,7 +879,12 @@ def _do_diffusion_step(
871879
log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): end")
872880

873881
log.debug("communication of prognistic.vn : start")
874-
handle_edge_comm = self._exchange(prognostic_state.vn, dim=dims.EdgeDim, wait=False)
882+
handle_edge_comm = self._exchange(
883+
prognostic_state.vn,
884+
dim=dims.EdgeDim,
885+
full_exchange=False,
886+
# stream=decomposition.DEFAULT_STREAM, # noqa: ERA001 # See NOTE above.
887+
)
875888

876889
log.debug(
877890
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start"
@@ -917,7 +930,8 @@ def _do_diffusion_step(
917930
log.debug("running stencil 13 to 16 apply_diffusion_to_theta_and_exner: end")
918931

919932
self.halo_exchange_wait(
920-
handle_edge_comm
933+
handle_edge_comm,
934+
# stream=decomposition.DEFAULT_STREAM, # noqa: ERA001 # See NOTE above.
921935
) # need to do this here, since we currently only use 1 communication object.
922936
log.debug("communication of prognogistic.vn - end")
923937

model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,8 +1151,11 @@ def run_predictor_step(
11511151
)
11521152

11531153
log.debug("exchanging prognostic field 'vn' and local field 'rho_at_edges_on_model_levels'")
1154-
self._exchange.exchange_and_wait(
1155-
dims.EdgeDim, prognostic_states.next.vn, z_fields.rho_at_edges_on_model_levels
1154+
self._exchange.exchange(
1155+
dims.EdgeDim,
1156+
prognostic_states.next.vn,
1157+
z_fields.rho_at_edges_on_model_levels,
1158+
stream=decomposition.DEFAULT_STREAM,
11561159
)
11571160

11581161
self._compute_horizontal_velocity_quantities_and_fluxes(
@@ -1223,12 +1226,19 @@ def run_predictor_step(
12231226
log.debug(
12241227
"exchanging prognostic field 'w' and local field 'dwdz_at_cells_on_model_levels'"
12251228
)
1226-
self._exchange.exchange_and_wait(
1227-
dims.CellDim, prognostic_states.next.w, z_fields.dwdz_at_cells_on_model_levels
1229+
self._exchange.exchange(
1230+
dims.CellDim,
1231+
prognostic_states.next.w,
1232+
z_fields.dwdz_at_cells_on_model_levels,
1233+
stream=decomposition.DEFAULT_STREAM,
12281234
)
12291235
else:
12301236
log.debug("exchanging prognostic field 'w'")
1231-
self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w)
1237+
self._exchange.exchange(
1238+
dims.CellDim,
1239+
prognostic_states.next.w,
1240+
stream=decomposition.DEFAULT_STREAM,
1241+
)
12321242

12331243
def run_corrector_step(
12341244
self,
@@ -1319,7 +1329,11 @@ def run_corrector_step(
13191329
)
13201330

13211331
log.debug("exchanging prognostic field 'vn'")
1322-
self._exchange.exchange_and_wait(dims.EdgeDim, (prognostic_states.next.vn))
1332+
self._exchange.exchange(
1333+
dims.EdgeDim,
1334+
prognostic_states.next.vn,
1335+
stream=decomposition.DEFAULT_STREAM,
1336+
)
13231337

13241338
self._compute_averaged_vn_and_fluxes(
13251339
spatially_averaged_vn=self.z_vn_avg,
@@ -1389,9 +1403,10 @@ def run_corrector_step(
13891403
)
13901404

13911405
log.debug("exchange prognostic fields 'rho' , 'exner', 'w'")
1392-
self._exchange.exchange_and_wait(
1406+
self._exchange.exchange(
13931407
dims.CellDim,
13941408
prognostic_states.next.rho,
13951409
prognostic_states.next.exner,
13961410
prognostic_states.next.w,
1411+
stream=decomposition.DEFAULT_STREAM,
13971412
)

0 commit comments

Comments
 (0)