Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions ext/cuda/operators_fd_shmem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞

Base.@propagate_inbounds function fd_operator_shmem(
space,
::Val{Nvt},
shmem_params,
op::Operators.DivergenceF2C,
args...,
) where {Nvt}
)
# allocate temp output
RT = return_eltype(op, args...)
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
lJu³ = CUDA.CuStaticSharedArray(RT, (1,))
rJu³ = CUDA.CuStaticSharedArray(RT, (1,))
Ju³ = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params))
lJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
rJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
return (Ju³, lJu³, rJu³)
end

Expand Down Expand Up @@ -109,15 +109,15 @@ end

Base.@propagate_inbounds function fd_operator_shmem(
space,
::Val{Nvt},
shmem_params,
op::Operators.GradientC2F,
args...,
) where {Nvt}
)
# allocate temp output
RT = return_eltype(op, args...)
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers
lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary
rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary
return (u, lb, rb)
end

Expand Down Expand Up @@ -202,15 +202,15 @@ end

Base.@propagate_inbounds function fd_operator_shmem(
space,
::Val{Nvt},
shmem_params,
op::Operators.InterpolateC2F,
args...,
) where {Nvt}
)
# allocate temp output
RT = return_eltype(op, args...)
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers
lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary
rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary
return (u, lb, rb)
end

Expand Down
32 changes: 18 additions & 14 deletions ext/cuda/operators_fd_shmem_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,23 @@ Base.@propagate_inbounds function getidx(
end

"""
fd_allocate_shmem(Val(Nvt), b)
fd_allocate_shmem(shmem_params, b)

Create a new broadcasted object with the necessary shared memory allocated,
using `Nvt` nodal points per block.
with shared-memory array sizes determined by `shmem_params`.
"""
@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt}
@inline function fd_allocate_shmem(::ShmemParams, obj)
obj
end
@inline function fd_allocate_shmem(
::Val{Nvt},
shmem_params::ShmemParams,
bc::Broadcasted{Style},
) where {Nvt, Style}
Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes)
) where {Style}
Broadcasted{Style}(
bc.f,
_fd_allocate_shmem(shmem_params, bc.args...),
bc.axes,
)
end

######### MatrixFields
Expand All @@ -236,22 +240,22 @@ end
#########

@inline function fd_allocate_shmem(
::Val{Nvt},
shmem_params::ShmemParams,
sbc::StencilBroadcasted{Style},
) where {Nvt, Style}
args = _fd_allocate_shmem(Val(Nvt), sbc.args...)
) where {Style}
args = _fd_allocate_shmem(shmem_params, sbc.args...)
work = if Operators.fd_shmem_is_supported(sbc)
fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...)
fd_operator_shmem(sbc.axes, shmem_params, sbc.op, args...)
else
nothing
end
StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work)
end

@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = ()
@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt} = (
fd_allocate_shmem(Val(Nvt), arg),
_fd_allocate_shmem(Val(Nvt), xargs...)...,
@inline _fd_allocate_shmem(::ShmemParams) = ()
@inline _fd_allocate_shmem(shmem_params::ShmemParams, arg, xargs...) = (
fd_allocate_shmem(shmem_params, arg),
_fd_allocate_shmem(shmem_params, xargs...)...,
)

"""
Expand Down
13 changes: 9 additions & 4 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Base.Broadcast.BroadcastStyle(

include("operators_fd_shmem_is_supported.jl")

"""
    ShmemParams{Nv}

Field-less type that carries `Nv` purely as a type parameter.
`interior_size` and `boundary_size` turn an instance into the dimension
tuples that the `fd_operator_shmem` methods hand to
`CUDA.CuStaticSharedArray`.
"""
struct ShmemParams{Nv} end

"""
    interior_size(shmem_params::ShmemParams)

Return the one-element tuple `(Nv,)` for a `ShmemParams{Nv}`.
"""
function interior_size(::ShmemParams{N}) where {N}
    return (N,)
end

"""
    boundary_size(shmem_params::ShmemParams)

Return `(1,)`; the result does not depend on the `Nv` parameter.
"""
function boundary_size(::ShmemParams)
    return (1,)
end

function Base.copyto!(
out::Field,
bc::Union{
Expand Down Expand Up @@ -56,6 +60,7 @@ function Base.copyto!(
mask isa NoMask &&
enough_shmem &&
Operators.use_fd_shmem()
shmem_params = ShmemParams{n_face_levels}()
p = fd_shmem_stencil_partition(us, n_face_levels)
args = (
strip_space(out, space),
Expand All @@ -64,7 +69,7 @@ function Base.copyto!(
bounds,
us,
mask,
Val(p.Nvthreads),
shmem_params,
)
auto_launch!(
copyto_stencil_kernel_shmem!,
Expand Down Expand Up @@ -153,8 +158,8 @@ function copyto_stencil_kernel_shmem!(
bds,
us,
mask,
::Val{Nvt},
) where {Nvt}
shmem_params::ShmemParams,
)
@inbounds begin
out_fv = Fields.field_values(out)
us = DataLayouts.UniversalSize(out_fv)
Expand All @@ -165,7 +170,7 @@ function copyto_stencil_kernel_shmem!(
hidx = (i, j, h)
idx = v - 1 + li
bc = Operators.reconstruct_placeholder_broadcasted(space, bc′)
bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem
bc_shmem = fd_allocate_shmem(shmem_params, bc) # allocates shmem

fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem
CUDA.sync_threads()
Expand Down
Loading