@@ -2,6 +2,7 @@ import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
22import CUDA
33import ClimaCore. Operators: return_eltype, get_local_geometry
44import ClimaCore. Geometry: ⊗
5+ import ClimaCore. RecursiveApply: ⊟ , ⊞
56
67Base. @propagate_inbounds function fd_operator_shmem (
78 space,
@@ -131,13 +132,11 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
131132 arg,
132133)
133134 @inbounds begin
135+ is_out_of_bounds (idx, space) && return nothing
134136 vt = threadIdx (). x
135137 cov3 = Geometry. Covariant3Vector (1 )
136138 if in_domain (idx, arg_space)
137139 u[vt] = cov3 ⊗ Operators. getidx (space, arg, idx, hidx)
138- else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
139- ᶜspace = Spaces. center_space (arg_space)
140- @assert idx == Spaces. nlevels (ᶜspace) + 1
141140 end
142141 if on_any_boundary (idx, space, op)
143142 lloc = Operators. left_boundary_window (space)
@@ -200,3 +199,114 @@ Base.@propagate_inbounds function fd_operator_evaluate(
200199 end
201200 end
202201end
202+
203+ Base. @propagate_inbounds function fd_operator_shmem (
204+ space,
205+ :: Val{Nvt} ,
206+ op:: Operators.InterpolateC2F ,
207+ args... ,
208+ ) where {Nvt}
209+ # allocate temp output
210+ RT = return_eltype (op, args... )
211+ u = CUDA. CuStaticSharedArray (RT, (Nvt,)) # cell centers
212+ lb = CUDA. CuStaticSharedArray (RT, (1 ,)) # left boundary
213+ rb = CUDA. CuStaticSharedArray (RT, (1 ,)) # right boundary
214+ return (u, lb, rb)
215+ end
216+
217+ Base. @propagate_inbounds function fd_operator_fill_shmem! (
218+ op:: Operators.InterpolateC2F ,
219+ (u, lb, rb),
220+ bc_bds,
221+ arg_space,
222+ space,
223+ idx:: Integer ,
224+ hidx,
225+ arg,
226+ )
227+ @inbounds begin
228+ is_out_of_bounds (idx, space) && return nothing
229+ ᶜidx = get_cent_idx (idx)
230+ if in_domain (idx, arg_space)
231+ u[idx] = Operators. getidx (space, arg, idx, hidx)
232+ else
233+ lloc = Operators. left_boundary_window (space)
234+ rloc = Operators. right_boundary_window (space)
235+ bloc = on_left_boundary (idx, space, op) ? lloc : rloc
236+ @assert bloc isa typeof (lloc) && on_left_boundary (idx, space, op) ||
237+ bloc isa typeof (rloc) && on_right_boundary (idx, space, op)
238+ bc = Operators. get_boundary (op, bloc)
239+ @assert bc isa Operators. SetValue ||
240+ bc isa Operators. SetGradient ||
241+ bc isa Operators. Extrapolate ||
242+ bc isa Operators. NullBoundaryCondition
243+ if bc isa Operators. NullBoundaryCondition ||
244+ bc isa Operators. Extrapolate
245+ u[idx] = Operators. getidx (space, arg, idx, hidx)
246+ return nothing
247+ end
248+ bu = on_left_boundary (idx, space) ? lb : rb
249+ ub = Operators. getidx (space, bc. val, nothing , hidx)
250+ if bc isa Operators. SetValue
251+ bu[1 ] = ub
252+ elseif bc isa Operators. SetGradient
253+ lg = Geometry. LocalGeometry (space, idx, hidx)
254+ bu[1 ] = Geometry. covariant3 (ub, lg)
255+ end
256+ end
257+ end
258+ return nothing
259+ end
260+
261+ Base. @propagate_inbounds function fd_operator_evaluate (
262+ op:: Operators.InterpolateC2F ,
263+ (u, lb, rb),
264+ space,
265+ idx:: PlusHalf ,
266+ hidx,
267+ args... ,
268+ )
269+ @inbounds begin
270+ vt = threadIdx (). x
271+ lg = Geometry. LocalGeometry (space, idx, hidx)
272+ ᶜidx = get_cent_idx (idx)
273+ if ! on_boundary (idx, space, op)
274+ u₋ = u[ᶜidx - 1 ] # corresponds to idx - half
275+ u₊ = u[ᶜidx] # corresponds to idx + half
276+ return RecursiveApply. rdiv (u₊ ⊞ u₋, 2 )
277+ else
278+ bloc =
279+ on_left_boundary (idx, space, op) ?
280+ Operators. left_boundary_window (space) :
281+ Operators. right_boundary_window (space)
282+ bc = Operators. get_boundary (op, bloc)
283+ @assert bc isa Operators. SetValue ||
284+ bc isa Operators. SetGradient ||
285+ bc isa Operators. Extrapolate
286+ if on_left_boundary (idx, space)
287+ if bc isa Operators. SetValue
288+ return lb[1 ]
289+ elseif bc isa Operators. SetGradient
290+ u₋ = lb[1 ] # corresponds to idx - half
291+ u₊ = u[ᶜidx] # corresponds to idx + half
292+ return u₊ ⊟ RecursiveApply. rdiv (u₋, 2 )
293+ else
294+ @assert bc isa Operators. Extrapolate
295+ return u[ᶜidx]
296+ end
297+ else
298+ @assert on_right_boundary (idx, space)
299+ if bc isa Operators. SetValue
300+ return rb[1 ]
301+ elseif bc isa Operators. SetGradient
302+ u₋ = u[ᶜidx - 1 ] # corresponds to idx - half
303+ u₊ = rb[1 ] # corresponds to idx + half
304+ return u₋ ⊞ RecursiveApply. rdiv (u₊, 2 )
305+ else
306+ @assert bc isa Operators. Extrapolate
307+ return u[ᶜidx - 1 ]
308+ end
309+ end
310+ end
311+ end
312+ end
0 commit comments