Skip to content

Commit 8dd6bd2

Browse files
huppdsamkellerhals
andauthored
Ignore delete statements in fused mode (#266)
This PR allows to use the DELETE statements only for the fused stencil mode. --------- Co-authored-by: Samuel <kellerhalssamuel@gmail.com>
1 parent 92e23fb commit 8dd6bd2

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed

tools/README.md

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,66 @@ Additionally, there are the following keyword arguments:
192192

193193
- `noaccenddata`: Takes a boolean string input and controls whether a `!$ACC END DATA` directive is generated or not. Defaults to false.<br><br>
194194

195+
#### `!$DSL FUSED START STENCIL()`
196+
197+
This directive denotes the start of a fused stencil. Required arguments are `name`, `vertical_lower`, `vertical_upper`, `horizontal_lower`, `horizontal_upper`. The value for `name` must correspond to a stencil found in one of the stencil modules inside `icon4py`, and all fields defined in the directive must correspond to the fields defined in the respective icon4py stencil. Optionally, absolute and relative tolerances for the output fields can also be set using the `_tol` or `_abs` suffixes respectively. For each stencil, an ACC ENTER/EXIT DATA statements will be created. This ACC ENTER/EXIT DATA region contains the before fileds of the according stencil. An example call looks like this:
198+
199+
```fortran
200+
!$DSL START FUSED STENCIL(name=calculate_diagnostic_quantities_for_turbulence; &
201+
!$DSL kh_smag_ec=kh_smag_ec(:,:,1); vn=p_nh_prog%vn(:,:,1); e_bln_c_s=p_int%e_bln_c_s(:,:,1); &
202+
!$DSL geofac_div=p_int%geofac_div(:,:,1); diff_multfac_smag=diff_multfac_smag(:); &
203+
!$DSL wgtfac_c=p_nh_metrics%wgtfac_c(:,:,1); div_ic=p_nh_diag%div_ic(:,:,1); &
204+
!$DSL hdef_ic=p_nh_diag%hdef_ic(:,:,1); &
205+
!$DSL div_ic_abs_tol=1e-18_wp; vertical_lower=2; &
206+
!$DSL vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx)
207+
```
208+
209+
#### `!$DSL END FUSED STENCIL()`
210+
211+
This directive denotes the end of a fused stencil. The required argument is `name`, which must match the name of the preceding `START STENCIL` directive.
212+
213+
Note that each `START STENCIL` and `END STENCIL` will be transformed into a `DELETE` section, when using the `--fused` mode.
214+
Together, the `START FUSED STENCIL` and `END FUSED STENCIL` directives result in the following generated code at the start and end of a stencil respectively.
215+
216+
```fortran
217+
!$ACC DATA CREATE( &
218+
!$ACC kh_smag_e_before, &
219+
!$ACC kh_smag_ec_before, &
220+
!$ACC z_nabla2_e_before ) &
221+
!$ACC IF ( i_am_accel_node )
222+
223+
#ifdef __DSL_VERIFY
224+
!$ACC KERNELS IF( i_am_accel_node ) DEFAULT(PRESENT) ASYNC(1)
225+
kh_smag_e_before(:, :, :) = kh_smag_e(:, :, :)
226+
kh_smag_ec_before(:, :, :) = kh_smag_ec(:, :, :)
227+
z_nabla2_e_before(:, :, :) = z_nabla2_e(:, :, :)
228+
!$ACC END KERNELS
229+
```
230+
231+
```fortran
232+
call wrap_run_calculate_diagnostic_quantities_for_turbulence( &
233+
kh_smag_ec=kh_smag_ec(:, :, 1), &
234+
vn=p_nh_prog%vn(:, :, 1), &
235+
e_bln_c_s=p_int%e_bln_c_s(:, :, 1), &
236+
geofac_div=p_int%geofac_div(:, :, 1), &
237+
diff_multfac_smag=diff_multfac_smag(:), &
238+
wgtfac_c=p_nh_metrics%wgtfac_c(:, :, 1), &
239+
div_ic=p_nh_diag%div_ic(:, :, 1), &
240+
div_ic_before=div_ic_before(:, :, 1), &
241+
hdef_ic=p_nh_diag%hdef_ic(:, :, 1), &
242+
hdef_ic_before=hdef_ic_before(:, :, 1), &
243+
div_ic_abs_tol=1e-18_wp, &
244+
vertical_lower=2, &
245+
vertical_upper=nlev, &
246+
horizontal_lower=i_startidx, &
247+
horizontal_upper=i_endidx)
248+
249+
!$ACC EXIT DATA DELETE( &
250+
!$ACC div_ic_before, &
251+
!$ACC hdef_ic_before ) &
252+
!$ACC IF ( i_am_accel_node )
253+
```
254+
195255
#### `!$DSL INSERT()`
196256

197257
This directive allows the user to generate any text that is placed between the parentheses. This is useful for situations where custom code generation is necessary.
@@ -204,6 +264,16 @@ This directive allows generating an nvtx start profile data statement, and takes
204264

205265
This directive allows generating an nvtx end profile statement.
206266

267+
#### `!$DSL START DELETE
268+
269+
This directive allows to disable code. The code is only disabled if both the fused mode and the substition mode are enabled.
270+
The `START DELETE` indicates the starting line from which on code is deleted.
271+
272+
#### `!$DSL END DELETE`
273+
274+
This directive allows to disable code. The code is only disabled if both the fused mode and the substition mode are enabled.
275+
The `END DELETE` indicates the ending line from which on code is deleted.
276+
207277
#### `!$DSL ENDIF()`
208278

209279
This directive generates an `#endif` statement.

tools/src/icon4pytools/liskov/parsing/transform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __call__(self, data: Any = None) -> IntegrationCodeInterface:
5454
else:
5555
logger.info("Removing fused stencils.")
5656
self._remove_fused_stencils()
57+
self._remove_delete()
5758

5859
return self.parsed
5960

@@ -113,3 +114,7 @@ def _remove_stencils(self, stencils_to_remove: list[CodeGenInput]) -> None:
113114
def _remove_fused_stencils(self) -> None:
114115
self.parsed.StartFusedStencil = []
115116
self.parsed.EndFusedStencil = []
117+
118+
def _remove_delete(self) -> None:
119+
self.parsed.StartDelete = []
120+
self.parsed.EndDelete = []
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# ICON4Py - ICON inspired code in Python and GT4Py
2+
#
3+
# Copyright (c) 2022, ETH Zurich and MeteoSwiss
4+
# All rights reserved.
5+
#
6+
# This file is free software: you can redistribute it and/or modify it under
7+
# the terms of the GNU General Public License as published by the
8+
# Free Software Foundation, either version 3 of the License, or any later
9+
# version. See the LICENSE.txt file at the top-level directory of this
10+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
11+
#
12+
# SPDX-License-Identifier: GPL-3.0-or-later
13+
14+
15+
import pytest
16+
17+
from icon4pytools.liskov.codegen.integration.interface import (
18+
BoundsData,
19+
DeclareData,
20+
EndCreateData,
21+
EndDeleteData,
22+
EndFusedStencilData,
23+
EndIfData,
24+
EndProfileData,
25+
EndStencilData,
26+
FieldAssociationData,
27+
ImportsData,
28+
InsertData,
29+
IntegrationCodeInterface,
30+
StartCreateData,
31+
StartDeleteData,
32+
StartFusedStencilData,
33+
StartProfileData,
34+
StartStencilData,
35+
)
36+
from icon4pytools.liskov.parsing.transform import StencilTransformer
37+
38+
39+
@pytest.fixture
40+
def integration_code_interface():
41+
start_fused_stencil_data = StartFusedStencilData(
42+
name="fused_stencil1",
43+
fields=[
44+
FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None),
45+
FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2),
46+
FieldAssociationData("out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5"),
47+
FieldAssociationData(
48+
"out2",
49+
"p_nh%prog(nnew)%out2(:,:,1)",
50+
inp=False,
51+
out=True,
52+
dims=3,
53+
abs_tol="0.2",
54+
),
55+
FieldAssociationData("out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2),
56+
FieldAssociationData("out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3),
57+
FieldAssociationData(
58+
"out5", "p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3
59+
),
60+
FieldAssociationData(
61+
"out6", "p_nh%prog(nnew)%w(:,:,1,ntnd)", inp=False, out=True, dims=3
62+
),
63+
],
64+
bounds=BoundsData("1", "10", "-1", "-10"),
65+
startln=1,
66+
acc_present=False,
67+
)
68+
end_fused_stencil_data = EndFusedStencilData(name="stencil1", startln=4)
69+
start_stencil_data1 = StartStencilData(
70+
name="stencil1",
71+
fields=[
72+
FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None),
73+
FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2),
74+
FieldAssociationData("out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5"),
75+
FieldAssociationData(
76+
"out2",
77+
"p_nh%prog(nnew)%out2(:,:,1)",
78+
inp=False,
79+
out=True,
80+
dims=3,
81+
abs_tol="0.2",
82+
),
83+
FieldAssociationData("out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2),
84+
FieldAssociationData("out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3),
85+
FieldAssociationData(
86+
"out5", "p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3
87+
),
88+
FieldAssociationData(
89+
"out6", "p_nh%prog(nnew)%w(:,:,1,ntnd)", inp=False, out=True, dims=3
90+
),
91+
],
92+
bounds=BoundsData("1", "10", "-1", "-10"),
93+
startln=2,
94+
acc_present=False,
95+
mergecopy=False,
96+
copies=True,
97+
)
98+
end_stencil_data1 = EndStencilData(
99+
name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False
100+
)
101+
start_stencil_data2 = StartStencilData(
102+
name="stencil2",
103+
fields=[
104+
FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None),
105+
FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2),
106+
FieldAssociationData("out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5"),
107+
FieldAssociationData(
108+
"out2",
109+
"p_nh%prog(nnew)%out2(:,:,1)",
110+
inp=False,
111+
out=True,
112+
dims=3,
113+
abs_tol="0.2",
114+
),
115+
FieldAssociationData("out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2),
116+
FieldAssociationData("out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3),
117+
FieldAssociationData(
118+
"out5", "p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3
119+
),
120+
FieldAssociationData(
121+
"out6", "p_nh%prog(nnew)%w(:,:,1,ntnd)", inp=False, out=True, dims=3
122+
),
123+
],
124+
bounds=BoundsData("1", "10", "-1", "-10"),
125+
startln=5,
126+
acc_present=False,
127+
mergecopy=False,
128+
copies=True,
129+
)
130+
end_stencil_data2 = EndStencilData(
131+
name="stencil2", startln=6, noendif=False, noprofile=False, noaccenddata=False
132+
)
133+
declare_data = DeclareData(
134+
startln=7,
135+
declarations={"field2": "(nproma, p_patch%nlev, p_patch%nblks_e)"},
136+
ident_type="REAL(wp)",
137+
suffix="before",
138+
)
139+
imports_data = ImportsData(startln=8)
140+
start_create_data = StartCreateData(extra_fields=["foo", "bar"], startln=9)
141+
end_create_data = EndCreateData(startln=11)
142+
endif_data = EndIfData(startln=12)
143+
start_profile_data = StartProfileData(startln=13, name="test_stencil")
144+
end_profile_data = EndProfileData(startln=14)
145+
insert_data = InsertData(startln=15, content="print *, 'Hello, World!'")
146+
start_delete_data = StartDeleteData(startln=16)
147+
end_delete_data = EndDeleteData(startln=17)
148+
149+
return IntegrationCodeInterface(
150+
StartStencil=[start_stencil_data1, start_stencil_data2],
151+
EndStencil=[end_stencil_data1, end_stencil_data2],
152+
StartFusedStencil=[start_fused_stencil_data],
153+
EndFusedStencil=[end_fused_stencil_data],
154+
StartDelete=[start_delete_data],
155+
EndDelete=[end_delete_data],
156+
Declare=[declare_data],
157+
Imports=imports_data,
158+
StartCreate=[start_create_data],
159+
EndCreate=[end_create_data],
160+
EndIf=[endif_data],
161+
StartProfile=[start_profile_data],
162+
EndProfile=[end_profile_data],
163+
Insert=[insert_data],
164+
)
165+
166+
167+
@pytest.fixture
168+
def stencil_transform_fused(integration_code_interface):
169+
return StencilTransformer(integration_code_interface, fused=True)
170+
171+
172+
@pytest.fixture
173+
def stencil_transform_unfused(integration_code_interface):
174+
return StencilTransformer(integration_code_interface, fused=False)
175+
176+
177+
def test_transform_fused(
178+
stencil_transform_fused,
179+
):
180+
# Check that the transformed interface is as expected
181+
transformed = stencil_transform_fused()
182+
assert len(transformed.StartFusedStencil) == 1
183+
assert len(transformed.EndFusedStencil) == 1
184+
assert len(transformed.StartStencil) == 1
185+
assert len(transformed.EndStencil) == 1
186+
assert len(transformed.StartDelete) == 2
187+
assert len(transformed.EndDelete) == 2
188+
189+
190+
def test_transform_unfused(
191+
stencil_transform_unfused,
192+
):
193+
# Check that the transformed interface is as expected
194+
transformed = stencil_transform_unfused()
195+
196+
assert not transformed.StartFusedStencil
197+
assert not transformed.EndFusedStencil
198+
assert len(transformed.StartStencil) == 2
199+
assert len(transformed.EndStencil) == 2
200+
assert not transformed.StartDelete
201+
assert not transformed.EndDelete

0 commit comments

Comments
 (0)