Skip to content

Commit e30d2b7

Browse files
charleskawczynskiCharlie Kawczynski
andauthored
Add shmem support for InterpolateC2F (#2290)
Co-authored-by: Charlie Kawczynski <[email protected]>
1 parent b47dffb commit e30d2b7

File tree

5 files changed

+245
-144
lines changed

5 files changed

+245
-144
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
22
import CUDA
33
import ClimaCore.Operators: return_eltype, get_local_geometry
44
import ClimaCore.Geometry:
5+
import ClimaCore.RecursiveApply: ,
56

67
Base.@propagate_inbounds function fd_operator_shmem(
78
space,
@@ -131,13 +132,11 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
131132
arg,
132133
)
133134
@inbounds begin
135+
is_out_of_bounds(idx, space) && return nothing
134136
vt = threadIdx().x
135137
cov3 = Geometry.Covariant3Vector(1)
136138
if in_domain(idx, arg_space)
137139
u[vt] = cov3 Operators.getidx(space, arg, idx, hidx)
138-
else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
139-
ᶜspace = Spaces.center_space(arg_space)
140-
@assert idx == Spaces.nlevels(ᶜspace) + 1
141140
end
142141
if on_any_boundary(idx, space, op)
143142
lloc = Operators.left_boundary_window(space)
@@ -200,3 +199,114 @@ Base.@propagate_inbounds function fd_operator_evaluate(
200199
end
201200
end
202201
end
202+
203+
Base.@propagate_inbounds function fd_operator_shmem(
204+
space,
205+
::Val{Nvt},
206+
op::Operators.InterpolateC2F,
207+
args...,
208+
) where {Nvt}
209+
# allocate temp output
210+
RT = return_eltype(op, args...)
211+
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
212+
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
213+
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
214+
return (u, lb, rb)
215+
end
216+
217+
Base.@propagate_inbounds function fd_operator_fill_shmem!(
218+
op::Operators.InterpolateC2F,
219+
(u, lb, rb),
220+
bc_bds,
221+
arg_space,
222+
space,
223+
idx::Integer,
224+
hidx,
225+
arg,
226+
)
227+
@inbounds begin
228+
is_out_of_bounds(idx, space) && return nothing
229+
ᶜidx = get_cent_idx(idx)
230+
if in_domain(idx, arg_space)
231+
u[idx] = Operators.getidx(space, arg, idx, hidx)
232+
else
233+
lloc = Operators.left_boundary_window(space)
234+
rloc = Operators.right_boundary_window(space)
235+
bloc = on_left_boundary(idx, space, op) ? lloc : rloc
236+
@assert bloc isa typeof(lloc) && on_left_boundary(idx, space, op) ||
237+
bloc isa typeof(rloc) && on_right_boundary(idx, space, op)
238+
bc = Operators.get_boundary(op, bloc)
239+
@assert bc isa Operators.SetValue ||
240+
bc isa Operators.SetGradient ||
241+
bc isa Operators.Extrapolate ||
242+
bc isa Operators.NullBoundaryCondition
243+
if bc isa Operators.NullBoundaryCondition ||
244+
bc isa Operators.Extrapolate
245+
u[idx] = Operators.getidx(space, arg, idx, hidx)
246+
return nothing
247+
end
248+
bu = on_left_boundary(idx, space) ? lb : rb
249+
ub = Operators.getidx(space, bc.val, nothing, hidx)
250+
if bc isa Operators.SetValue
251+
bu[1] = ub
252+
elseif bc isa Operators.SetGradient
253+
lg = Geometry.LocalGeometry(space, idx, hidx)
254+
bu[1] = Geometry.covariant3(ub, lg)
255+
end
256+
end
257+
end
258+
return nothing
259+
end
260+
261+
Base.@propagate_inbounds function fd_operator_evaluate(
262+
op::Operators.InterpolateC2F,
263+
(u, lb, rb),
264+
space,
265+
idx::PlusHalf,
266+
hidx,
267+
args...,
268+
)
269+
@inbounds begin
270+
vt = threadIdx().x
271+
lg = Geometry.LocalGeometry(space, idx, hidx)
272+
ᶜidx = get_cent_idx(idx)
273+
if !on_boundary(idx, space, op)
274+
u₋ = u[ᶜidx - 1] # corresponds to idx - half
275+
u₊ = u[ᶜidx] # corresponds to idx + half
276+
return RecursiveApply.rdiv(u₊ u₋, 2)
277+
else
278+
bloc =
279+
on_left_boundary(idx, space, op) ?
280+
Operators.left_boundary_window(space) :
281+
Operators.right_boundary_window(space)
282+
bc = Operators.get_boundary(op, bloc)
283+
@assert bc isa Operators.SetValue ||
284+
bc isa Operators.SetGradient ||
285+
bc isa Operators.Extrapolate
286+
if on_left_boundary(idx, space)
287+
if bc isa Operators.SetValue
288+
return lb[1]
289+
elseif bc isa Operators.SetGradient
290+
u₋ = lb[1] # corresponds to idx - half
291+
u₊ = u[ᶜidx] # corresponds to idx + half
292+
return u₊ RecursiveApply.rdiv(u₋, 2)
293+
else
294+
@assert bc isa Operators.Extrapolate
295+
return u[ᶜidx]
296+
end
297+
else
298+
@assert on_right_boundary(idx, space)
299+
if bc isa Operators.SetValue
300+
return rb[1]
301+
elseif bc isa Operators.SetGradient
302+
u₋ = u[ᶜidx - 1] # corresponds to idx - half
303+
u₊ = rb[1] # corresponds to idx + half
304+
return u₋ RecursiveApply.rdiv(u₊, 2)
305+
else
306+
@assert bc isa Operators.Extrapolate
307+
return u[ᶜidx - 1]
308+
end
309+
end
310+
end
311+
end
312+
end

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ import ClimaCore.Utilities
4343
(has_left_boundary(space, op) && on_left_boundary(idx, space)) ||
4444
has_right_boundary(space, op) && on_right_boundary(idx, space)
4545

46+
@inline function is_out_of_bounds(idx::Integer, space)
47+
ᶜspace = Spaces.center_space(space)
48+
return idx == Spaces.nlevels(ᶜspace) + 1
49+
end
50+
4651
#####
4752
##### range window helpers (faces)
4853
#####
@@ -348,6 +353,15 @@ Base.@propagate_inbounds function fd_resolve_shmem!(
348353
return nothing
349354
end
350355

356+
Base.@propagate_inbounds function fd_resolve_shmem!(
357+
sbc::StencilBroadcasted,
358+
idx, # top-level index
359+
hidx,
360+
bds,
361+
)
362+
_fd_resolve_shmem!(idx, hidx, bds, sbc.args...)
363+
end
364+
351365
Base.@propagate_inbounds _fd_resolve_shmem!(idx, hidx, bds) = nothing
352366
Base.@propagate_inbounds function _fd_resolve_shmem!(
353367
idx,
@@ -360,12 +374,8 @@ Base.@propagate_inbounds function _fd_resolve_shmem!(
360374
_fd_resolve_shmem!(idx, hidx, bds, xargs...)
361375
end
362376

363-
Base.@propagate_inbounds fd_resolve_shmem!(
364-
bc::Broadcasted{CUDAWithShmemColumnStencilStyle},
365-
idx,
366-
hidx,
367-
bds,
368-
) = _fd_resolve_shmem!(idx, hidx, bds, bc.args...)
377+
Base.@propagate_inbounds fd_resolve_shmem!(bc::Broadcasted, idx, hidx, bds) =
378+
_fd_resolve_shmem!(idx, hidx, bds, bc.args...)
369379
@inline fd_resolve_shmem!(obj, idx, hidx, bds) = nothing
370380

371381
if hasfield(Method, :recursion_relation)

ext/cuda/operators_fd_shmem_is_supported.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end
162162
bcs::NamedTuple,
163163
) =
164164
all(values(bcs)) do bc
165-
all(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
165+
any(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
166166
end
167167

168168
##### GradientC2F
@@ -177,5 +177,23 @@ end
177177
bcs::NamedTuple,
178178
) =
179179
all(values(bcs)) do bc
180-
all(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
180+
any(supported_bc -> bc isa supported_bc, (Operators.SetValue,))
181+
end
182+
183+
##### InterpolateC2F
184+
@inline Operators.fd_shmem_is_supported(op::Operators.InterpolateC2F) =
185+
Operators.fd_shmem_is_supported(op, op.bcs)
186+
@inline Operators.fd_shmem_is_supported(
187+
op::Operators.InterpolateC2F,
188+
::@NamedTuple{},
189+
) = true
190+
@inline Operators.fd_shmem_is_supported(
191+
op::Operators.InterpolateC2F,
192+
bcs::NamedTuple,
193+
) =
194+
all(values(bcs)) do bc
195+
any(
196+
supported_bc -> bc isa supported_bc,
197+
(Operators.SetValue, Operators.SetGradient, Operators.Extrapolate),
198+
)
181199
end

0 commit comments

Comments
 (0)