@@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞
66
77Base. @propagate_inbounds function fd_operator_shmem (
88 space,
9- :: Val{Nvt} ,
9+ params ,
1010 op:: Operators.DivergenceF2C ,
1111 args... ,
12- ) where {Nvt}
12+ )
1313 # allocate temp output
1414 RT = return_eltype (op, args... )
15- Ju³ = CUDA. CuStaticSharedArray (RT, (Nvt, ))
16- lJu³ = CUDA. CuStaticSharedArray (RT, ( 1 , ))
17- rJu³ = CUDA. CuStaticSharedArray (RT, ( 1 , ))
15+ Ju³ = CUDA. CuStaticSharedArray (RT, shmem_size (params ))
16+ lJu³ = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( ))
17+ rJu³ = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( ))
1818 return (Ju³, lJu³, rJu³)
1919end
2020
@@ -29,20 +29,21 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
2929 arg,
3030)
3131 @inbounds begin
32- vt = threadIdx (). x
32+ si = FDShmemIndex ()
33+ bi = FDShmemBoundaryIndex ()
3334 lg = Geometry. LocalGeometry (space, idx, hidx)
3435 if ! on_boundary (idx, space, op)
3536 u³ = Operators. getidx (space, arg, idx, hidx)
36- Ju³[vt ] = Geometry. Jcontravariant3 (u³, lg)
37+ Ju³[si ] = Geometry. Jcontravariant3 (u³, lg)
3738 elseif on_left_boundary (idx, space, op)
3839 bloc = Operators. left_boundary_window (space)
3940 bc = Operators. get_boundary (op, bloc)
4041 ub = Operators. getidx (space, bc. val, nothing , hidx)
4142 bJu³ = on_left_boundary (idx, space) ? lJu³ : rJu³
4243 if bc isa Operators. SetValue
43- bJu³[1 ] = Geometry. Jcontravariant3 (ub, lg)
44+ bJu³[bi ] = Geometry. Jcontravariant3 (ub, lg)
4445 elseif bc isa Operators. SetDivergence
45- bJu³[1 ] = ub
46+ bJu³[bi ] = ub
4647 elseif bc isa Operators. Extrapolate # no shmem needed
4748 end
4849 elseif on_right_boundary (idx, space, op)
@@ -51,9 +52,9 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
5152 ub = Operators. getidx (space, bc. val, nothing , hidx)
5253 bJu³ = on_left_boundary (idx, space) ? lJu³ : rJu³
5354 if bc isa Operators. SetValue
54- bJu³[1 ] = Geometry. Jcontravariant3 (ub, lg)
55+ bJu³[bi ] = Geometry. Jcontravariant3 (ub, lg)
5556 elseif bc isa Operators. SetDivergence
56- bJu³[1 ] = ub
57+ bJu³[bi ] = ub
5758 elseif bc isa Operators. Extrapolate # no shmem needed
5859 end
5960 end
@@ -70,11 +71,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
7071 arg,
7172)
7273 @inbounds begin
73- vt = threadIdx (). x
74+ si = FDShmemIndex ()
75+ bi = FDShmemBoundaryIndex ()
7476 lg = Geometry. LocalGeometry (space, idx, hidx)
7577 if ! on_boundary (idx, space, op)
76- Ju³₋ = Ju³[vt ] # corresponds to idx - half
77- Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
78+ Ju³₋ = Ju³[si ] # corresponds to idx - half
79+ Ju³₊ = Ju³[si + 1 ] # corresponds to idx + half
7880 return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
7981 else
8082 bloc =
@@ -85,8 +87,8 @@ Base.@propagate_inbounds function fd_operator_evaluate(
8587 @assert bc isa Operators. SetValue || bc isa Operators. SetDivergence
8688 if on_left_boundary (idx, space)
8789 if bc isa Operators. SetValue
88- Ju³₋ = lJu³[1 ] # corresponds to idx - half
89- Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
90+ Ju³₋ = lJu³[bi ] # corresponds to idx - half
91+ Ju³₊ = Ju³[si + 1 ] # corresponds to idx + half
9092 return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
9193 else
9294 # @assert bc isa Operators.SetDivergence
@@ -95,12 +97,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
9597 else
9698 @assert on_right_boundary (idx, space)
9799 if bc isa Operators. SetValue
98- Ju³₋ = Ju³[vt ] # corresponds to idx - half
99- Ju³₊ = rJu³[1 ] # corresponds to idx + half
100+ Ju³₋ = Ju³[si ] # corresponds to idx - half
101+ Ju³₊ = rJu³[bi ] # corresponds to idx + half
100102 return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
101103 else
102104 @assert bc isa Operators. SetDivergence
103- return rJu³[1 ]
105+ return rJu³[bi ]
104106 end
105107 end
106108 end
@@ -109,15 +111,15 @@ end
109111
110112Base. @propagate_inbounds function fd_operator_shmem (
111113 space,
112- :: Val{Nvt} ,
114+ params ,
113115 op:: Operators.GradientC2F ,
114116 args... ,
115- ) where {Nvt}
117+ )
116118 # allocate temp output
117119 RT = return_eltype (op, args... )
118- u = CUDA. CuStaticSharedArray (RT, (Nvt, )) # cell centers
119- lb = CUDA. CuStaticSharedArray (RT, ( 1 , )) # left boundary
120- rb = CUDA. CuStaticSharedArray (RT, ( 1 , )) # right boundary
120+ u = CUDA. CuStaticSharedArray (RT, shmem_size (params )) # cell centers
121+ lb = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( )) # left boundary
122+ rb = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( )) # right boundary
121123 return (u, lb, rb)
122124end
123125
@@ -132,11 +134,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
132134 arg,
133135)
134136 @inbounds begin
137+ si = FDShmemIndex (idx)
138+ bi = FDShmemBoundaryIndex ()
135139 is_out_of_bounds (idx, space) && return nothing
136- vt = threadIdx (). x
137140 cov3 = Geometry. Covariant3Vector (1 )
138141 if in_domain (idx, arg_space)
139- u[vt ] = cov3 ⊗ Operators. getidx (space, arg, idx, hidx)
142+ u[si ] = cov3 ⊗ Operators. getidx (space, arg, idx, hidx)
140143 end
141144 if on_any_boundary (idx, space, op)
142145 lloc = Operators. left_boundary_window (space)
@@ -149,10 +152,10 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
149152 ub = Operators. getidx (space, bc. val, nothing , hidx)
150153 bu = on_left_boundary (idx, space) ? lb : rb
151154 if bc isa Operators. SetValue
152- bu[1 ] = cov3 ⊗ ub
155+ bu[bi ] = cov3 ⊗ ub
153156 elseif bc isa Operators. SetGradient
154157 lg = Geometry. LocalGeometry (space, idx, hidx)
155- bu[1 ] = Geometry. project (Geometry. Covariant3Axis (), ub, lg)
158+ bu[bi ] = Geometry. project (Geometry. Covariant3Axis (), ub, lg)
156159 elseif bc isa Operators. Extrapolate # no shmem needed
157160 end
158161 end
@@ -169,11 +172,12 @@ Base.@propagate_inbounds function fd_operator_evaluate(
169172 args... ,
170173)
171174 @inbounds begin
172- vt = threadIdx (). x
175+ si = FDShmemIndex ()
176+ bi = FDShmemBoundaryIndex ()
173177 lg = Geometry. LocalGeometry (space, idx, hidx)
174178 if ! on_boundary (idx, space, op)
175- u₋ = u[vt - 1 ] # corresponds to idx - half
176- u₊ = u[vt ] # corresponds to idx + half
179+ u₋ = u[si - 1 ] # corresponds to idx - half
180+ u₊ = u[si ] # corresponds to idx + half
177181 return u₊ ⊟ u₋
178182 else
179183 bloc =
@@ -184,15 +188,15 @@ Base.@propagate_inbounds function fd_operator_evaluate(
184188 @assert bc isa Operators. SetValue
185189 if on_left_boundary (idx, space)
186190 if bc isa Operators. SetValue
187- u₋ = 2 * lb[1 ] # corresponds to idx - half
188- u₊ = 2 * u[vt ] # corresponds to idx + half
191+ u₋ = 2 * lb[bi ] # corresponds to idx - half
192+ u₊ = 2 * u[si ] # corresponds to idx + half
189193 return u₊ ⊟ u₋
190194 end
191195 else
192196 @assert on_right_boundary (idx, space)
193197 if bc isa Operators. SetValue
194- u₋ = 2 * u[vt - 1 ] # corresponds to idx - half
195- u₊ = 2 * rb[1 ] # corresponds to idx + half
198+ u₋ = 2 * u[si - 1 ] # corresponds to idx - half
199+ u₊ = 2 * rb[bi ] # corresponds to idx + half
196200 return u₊ ⊟ u₋
197201 end
198202 end
@@ -202,15 +206,15 @@ end
202206
203207Base. @propagate_inbounds function fd_operator_shmem (
204208 space,
205- :: Val{Nvt} ,
209+ params ,
206210 op:: Operators.InterpolateC2F ,
207211 args... ,
208- ) where {Nvt}
212+ )
209213 # allocate temp output
210214 RT = return_eltype (op, args... )
211- u = CUDA. CuStaticSharedArray (RT, (Nvt, )) # cell centers
212- lb = CUDA. CuStaticSharedArray (RT, ( 1 , )) # left boundary
213- rb = CUDA. CuStaticSharedArray (RT, ( 1 , )) # right boundary
215+ u = CUDA. CuStaticSharedArray (RT, shmem_size (params )) # cell centers
216+ lb = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( )) # left boundary
217+ rb = CUDA. CuStaticSharedArray (RT, boundary_shmem_size ( )) # right boundary
214218 return (u, lb, rb)
215219end
216220
@@ -225,10 +229,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
225229 arg,
226230)
227231 @inbounds begin
228- is_out_of_bounds (idx, space) && return nothing
229232 ᶜidx = get_cent_idx (idx)
233+ si = FDShmemIndex (idx)
234+ bi = FDShmemBoundaryIndex ()
235+ is_out_of_bounds (idx, space) && return nothing
230236 if in_domain (idx, arg_space)
231- u[idx ] = Operators. getidx (space, arg, idx, hidx)
237+ u[si ] = Operators. getidx (space, arg, idx, hidx)
232238 else
233239 lloc = Operators. left_boundary_window (space)
234240 rloc = Operators. right_boundary_window (space)
@@ -242,16 +248,16 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
242248 bc isa Operators. NullBoundaryCondition
243249 if bc isa Operators. NullBoundaryCondition ||
244250 bc isa Operators. Extrapolate
245- u[idx ] = Operators. getidx (space, arg, idx, hidx)
251+ u[si ] = Operators. getidx (space, arg, idx, hidx)
246252 return nothing
247253 end
248254 bu = on_left_boundary (idx, space) ? lb : rb
249255 ub = Operators. getidx (space, bc. val, nothing , hidx)
250256 if bc isa Operators. SetValue
251- bu[1 ] = ub
257+ bu[bi ] = ub
252258 elseif bc isa Operators. SetGradient
253259 lg = Geometry. LocalGeometry (space, idx, hidx)
254- bu[1 ] = Geometry. covariant3 (ub, lg)
260+ bu[bi ] = Geometry. covariant3 (ub, lg)
255261 end
256262 end
257263 end
@@ -267,12 +273,13 @@ Base.@propagate_inbounds function fd_operator_evaluate(
267273 args... ,
268274)
269275 @inbounds begin
270- vt = threadIdx (). x
271- lg = Geometry. LocalGeometry (space, idx, hidx)
272276 ᶜidx = get_cent_idx (idx)
277+ si = FDShmemIndex (ᶜidx)
278+ bi = FDShmemBoundaryIndex ()
279+ lg = Geometry. LocalGeometry (space, idx, hidx)
273280 if ! on_boundary (idx, space, op)
274- u₋ = u[ᶜidx - 1 ] # corresponds to idx - half
275- u₊ = u[ᶜidx ] # corresponds to idx + half
281+ u₋ = u[si - 1 ] # corresponds to idx - half
282+ u₊ = u[si ] # corresponds to idx + half
276283 return RecursiveApply. rdiv (u₊ ⊞ u₋, 2 )
277284 else
278285 bloc =
@@ -285,26 +292,26 @@ Base.@propagate_inbounds function fd_operator_evaluate(
285292 bc isa Operators. Extrapolate
286293 if on_left_boundary (idx, space)
287294 if bc isa Operators. SetValue
288- return lb[1 ]
295+ return lb[bi ]
289296 elseif bc isa Operators. SetGradient
290- u₋ = lb[1 ] # corresponds to idx - half
291- u₊ = u[ᶜidx ] # corresponds to idx + half
297+ u₋ = lb[bi ] # corresponds to idx - half
298+ u₊ = u[si ] # corresponds to idx + half
292299 return u₊ ⊟ RecursiveApply. rdiv (u₋, 2 )
293300 else
294301 @assert bc isa Operators. Extrapolate
295- return u[ᶜidx ]
302+ return u[si ]
296303 end
297304 else
298305 @assert on_right_boundary (idx, space)
299306 if bc isa Operators. SetValue
300- return rb[1 ]
307+ return rb[bi ]
301308 elseif bc isa Operators. SetGradient
302- u₋ = u[ᶜidx - 1 ] # corresponds to idx - half
303- u₊ = rb[1 ] # corresponds to idx + half
309+ u₋ = u[si - 1 ] # corresponds to idx - half
310+ u₊ = rb[bi ] # corresponds to idx + half
304311 return u₋ ⊞ RecursiveApply. rdiv (u₊, 2 )
305312 else
306313 @assert bc isa Operators. Extrapolate
307- return u[ᶜidx - 1 ]
314+ return u[si - 1 ]
308315 end
309316 end
310317 end
0 commit comments