@@ -9,17 +9,13 @@ import ClimaCore.Utilities
99# #### Boundary helpers
1010# ####
1111
12- @inline function has_left_boundary (space, op)
13- lloc = Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
14- return Operators. has_boundary (op, lloc)
15- end
16- @inline function has_right_boundary (space, op)
17- rloc = Operators. RightBoundaryWindow {Spaces.right_boundary_name(space)} ()
18- return Operators. has_boundary (op, rloc)
19- end
12+ @inline has_left_boundary (space, op) =
13+ Operators. has_boundary (op, Operators. left_boundary_window (space))
14+ @inline has_right_boundary (space, op) =
15+ Operators. has_boundary (op, Operators. right_boundary_window (space))
2016
21- @inline on_boundary (space, op, loc, idx ) =
22- Operators . has_boundary (op, loc) && on_boundary (idx, space)
17+ @inline on_boundary (idx, space, op ) =
18+ on_left_boundary (idx, space, op) || on_right_boundary (idx, space, op )
2319
2420@inline on_left_boundary (idx, space, op) =
2521 has_left_boundary (space, op) && on_left_boundary (idx, space)
9288 op,
9389 args... ,
9490)
95- lloc = Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
96- Operators. should_call_left_boundary (idx, space, lloc, op, args... ) ||
91+ Operators. should_call_left_boundary (idx, space, op, args... ) ||
9792 in_left_boundary_window_range (idx, bc_bds)
9893end
9994
10499 op,
105100 args... ,
106101)
107- rloc = Operators. RightBoundaryWindow {Spaces.right_boundary_name(space)} ()
108- Operators. should_call_right_boundary (idx, space, rloc, op, args... ) ||
102+ Operators. should_call_right_boundary (idx, space, op, args... ) ||
109103 in_right_boundary_window_range (idx, bc_bds)
110104end
111105
146140 op,
147141 args... ,
148142)
149- lloc = Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
150- Operators. should_call_left_boundary (idx, space, lloc, op, args... ) ||
143+ Operators. should_call_left_boundary (idx, space, op, args... ) ||
151144 in_left_boundary_window_range (idx, bc_bds)
152145end
153146
158151 op,
159152 args... ,
160153)
161- rloc = Operators. RightBoundaryWindow {Spaces.right_boundary_name(space)} ()
162154 ᶜspace = Spaces. center_space (space)
163155 idx > Spaces. nlevels (ᶜspace) && return false # short-circuit if
164- Operators. should_call_right_boundary (idx, space, rloc, op, args... ) ||
156+ Operators. should_call_right_boundary (idx, space, op, args... ) ||
165157 in_right_boundary_window_range (idx, bc_bds)
166158end
167159
172164Base. @propagate_inbounds function getidx (
173165 parent_space,
174166 bc:: StencilBroadcasted{CUDAWithShmemColumnStencilStyle} ,
175- loc:: Interior ,
176- idx,
177- hidx,
178- )
179- space = axes (bc)
180- if Operators. fd_shmem_is_supported (bc)
181- return fd_operator_evaluate (
182- bc. op,
183- bc. work,
184- loc,
185- space,
186- idx,
187- hidx,
188- bc. args... ,
189- )
190- end
191- Operators. stencil_interior (bc. op, loc, space, idx, hidx, bc. args... )
192- end
193-
194-
195- Base. @propagate_inbounds function getidx (
196- parent_space,
197- bc:: StencilBroadcasted{CUDAWithShmemColumnStencilStyle} ,
198- loc:: Operators.LeftBoundaryWindow ,
199167 idx,
200168 hidx,
201169)
@@ -204,63 +172,34 @@ Base.@propagate_inbounds function getidx(
204172 return fd_operator_evaluate (
205173 bc. op,
206174 bc. work,
207- loc,
208175 space,
209176 idx,
210177 hidx,
211178 bc. args... ,
212179 )
213180 end
214181 op = bc. op
215- if Operators. should_call_left_boundary (idx, space, loc, bc. op, bc. args... )
182+ if Operators. should_call_left_boundary (idx, space, bc. op, bc. args... )
216183 Operators. stencil_left_boundary (
217184 op,
218- Operators. get_boundary (op, loc),
219- loc,
220- space,
221- idx,
222- hidx,
223- bc. args... ,
224- )
225- else
226- # fallback to interior stencil
227- Operators. stencil_interior (op, loc, space, idx, hidx, bc. args... )
228- end
229- end
230-
231- Base. @propagate_inbounds function getidx (
232- parent_space,
233- bc:: StencilBroadcasted{CUDAWithShmemColumnStencilStyle} ,
234- loc:: Operators.RightBoundaryWindow ,
235- idx,
236- hidx,
237- )
238- space = axes (bc)
239- if Operators. fd_shmem_is_supported (bc)
240- return fd_operator_evaluate (
241- bc. op,
242- bc. work,
243- loc,
185+ Operators. get_boundary (op, Operators. left_boundary_window (space)),
244186 space,
245187 idx,
246188 hidx,
247189 bc. args... ,
248190 )
249- end
250- op = bc. op
251- if Operators. should_call_right_boundary (idx, space, loc, bc. op, bc. args... )
191+ elseif Operators. should_call_right_boundary (idx, space, bc. op, bc. args... )
252192 Operators. stencil_right_boundary (
253193 op,
254- Operators. get_boundary (op, loc),
255- loc,
194+ Operators. get_boundary (op, Operators. right_boundary_window (space)),
256195 space,
257196 idx,
258197 hidx,
259198 bc. args... ,
260199 )
261200 else
262201 # fallback to interior stencil
263- Operators. stencil_interior (op, loc, space, idx, hidx, bc. args... )
202+ Operators. stencil_interior (op, space, idx, hidx, bc. args... )
264203 end
265204end
266205
@@ -375,9 +314,6 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
375314)
376315 (li, lw, rw, ri) = bds
377316 space = axes (sbc)
378-
379- ᶜspace = Spaces. center_space (space)
380- ᶠspace = Spaces. face_space (space)
381317 arg_space = get_arg_space (sbc, sbc. args)
382318 ᶜidx = get_cent_idx (idx)
383319 ᶠidx = get_face_idx (idx)
@@ -387,13 +323,6 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
387323 # After recursion, check if shmem is supported for this operator
388324 Operators. fd_shmem_is_supported (sbc) || return nothing
389325
390- (; op) = sbc
391- lloc = Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
392- rloc = Operators. RightBoundaryWindow {Spaces.right_boundary_name(space)} ()
393- iloc = Operators. Interior ()
394-
395- IP = Topologies. isperiodic (Spaces. vertical_topology (space))
396-
397326 # There are `Nf` threads, where `Nf` is the number of face levels. So,
398327 # each thread is responsible for filling shared memory at its cell center
399328 # (if the broadcasted argument lives on cell centers)
@@ -403,52 +332,18 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
403332 # (the space of all broadcasted arguments must all match, so using the first is valid).
404333
405334 bc_bds = Operators. window_bounds (space, sbc)
406- (bc_li, bc_lw, bc_rw, bc_ri) = bc_bds
407335 ᵃidx = arg_space isa Operators. AllFaceFiniteDifferenceSpace ? ᶠidx : ᶜidx
408336
409- if in_interior (ᵃidx, arg_space, bc_bds, sbc. op, sbc. args... )
410- fd_operator_fill_shmem! (
411- sbc. op,
412- sbc. work,
413- iloc,
414- bc_bds,
415- arg_space,
416- space,
417- ᵃidx,
418- hidx,
419- sbc. args... ,
420- )
421- elseif in_left_boundary_window (ᵃidx, arg_space, bc_bds, sbc. op, sbc. args... )
422- fd_operator_fill_shmem! (
423- sbc. op,
424- sbc. work,
425- lloc,
426- bc_bds,
427- arg_space,
428- space,
429- ᵃidx,
430- hidx,
431- sbc. args... ,
432- )
433- elseif in_right_boundary_window (
434- ᵃidx,
435- arg_space,
436- bc_bds,
337+ fd_operator_fill_shmem! (
437338 sbc. op,
339+ sbc. work,
340+ bc_bds,
341+ arg_space,
342+ space,
343+ ᵃidx,
344+ hidx,
438345 sbc. args... ,
439346 )
440- fd_operator_fill_shmem! (
441- sbc. op,
442- sbc. work,
443- rloc,
444- bc_bds,
445- arg_space,
446- space,
447- ᵃidx,
448- hidx,
449- sbc. args... ,
450- )
451- end
452347 CUDA. sync_threads ()
453348 return nothing
454349end
0 commit comments