From 97b84fdba4bd7ab032afd1681375cc93db36b127 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Sat, 26 Apr 2025 13:34:40 -0400 Subject: [PATCH] Generalize shmem size --- ext/cuda/operators_fd_shmem.jl | 30 +++++++++++------------ ext/cuda/operators_fd_shmem_common.jl | 32 ++++++++++++++----------- ext/cuda/operators_finite_difference.jl | 13 ++++++---- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/ext/cuda/operators_fd_shmem.jl b/ext/cuda/operators_fd_shmem.jl index 0f0ce921d0..da1d3ffdd9 100644 --- a/ext/cuda/operators_fd_shmem.jl +++ b/ext/cuda/operators_fd_shmem.jl @@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞ Base.@propagate_inbounds function fd_operator_shmem( space, - ::Val{Nvt}, + shmem_params, op::Operators.DivergenceF2C, args..., -) where {Nvt} +) # allocate temp output RT = return_eltype(op, args...) - Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,)) - lJu³ = CUDA.CuStaticSharedArray(RT, (1,)) - rJu³ = CUDA.CuStaticSharedArray(RT, (1,)) + Ju³ = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) + lJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) + rJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) return (Ju³, lJu³, rJu³) end @@ -109,15 +109,15 @@ end Base.@propagate_inbounds function fd_operator_shmem( space, - ::Val{Nvt}, + shmem_params, op::Operators.GradientC2F, args..., -) where {Nvt} +) # allocate temp output RT = return_eltype(op, args...) - u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers - lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary - rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary + u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers + lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary + rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary return (u, lb, rb) end @@ -202,15 +202,15 @@ end Base.@propagate_inbounds function fd_operator_shmem( space, - ::Val{Nvt}, + shmem_params, op::Operators.InterpolateC2F, args..., -) where {Nvt} +) # allocate temp output RT = return_eltype(op, args...) - u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers - lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary - rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary + u = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params)) # cell centers + lb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # left boundary + rb = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params)) # right boundary return (u, lb, rb) end diff --git a/ext/cuda/operators_fd_shmem_common.jl b/ext/cuda/operators_fd_shmem_common.jl index 6c04ce9f59..6aded04ead 100644 --- a/ext/cuda/operators_fd_shmem_common.jl +++ b/ext/cuda/operators_fd_shmem_common.jl @@ -209,19 +209,23 @@ Base.@propagate_inbounds function getidx( end """ - fd_allocate_shmem(Val(Nvt), b) + fd_allocate_shmem(shmem_params, b) Create a new broadcasted object with necessary share memory allocated, -using `Nvt` nodal points per block. +using `params` nodal points per block. """ -@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt} +@inline function fd_allocate_shmem(::ShmemParams, obj) obj end @inline function fd_allocate_shmem( - ::Val{Nvt}, + shmem_params::ShmemParams, bc::Broadcasted{Style}, -) where {Nvt, Style} - Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes) +) where {Style} + Broadcasted{Style}( + bc.f, + _fd_allocate_shmem(shmem_params, bc.args...), + bc.axes, + ) end ######### MatrixFields @@ -236,22 +240,22 @@ end ######### @inline function fd_allocate_shmem( - ::Val{Nvt}, + shmem_params::ShmemParams, sbc::StencilBroadcasted{Style}, -) where {Nvt, Style} - args = _fd_allocate_shmem(Val(Nvt), sbc.args...) +) where {Style} + args = _fd_allocate_shmem(shmem_params, sbc.args...) work = if Operators.fd_shmem_is_supported(sbc) - fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...) + fd_operator_shmem(sbc.axes, shmem_params, sbc.op, args...) else nothing end StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work) end -@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = () -@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt} = ( - fd_allocate_shmem(Val(Nvt), arg), - _fd_allocate_shmem(Val(Nvt), xargs...)..., +@inline _fd_allocate_shmem(::ShmemParams) = () +@inline _fd_allocate_shmem(shmem_params::ShmemParams, arg, xargs...) = ( + fd_allocate_shmem(shmem_params, arg), + _fd_allocate_shmem(shmem_params, xargs...)..., ) """ diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index 86a5cb5cec..cd2f8ad09a 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -22,6 +22,10 @@ Base.Broadcast.BroadcastStyle( include("operators_fd_shmem_is_supported.jl") +struct ShmemParams{Nv} end +interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,) +boundary_size(::ShmemParams{Nv}) where {Nv} = (1,) + function Base.copyto!( out::Field, bc::Union{ @@ -56,6 +60,7 @@ function Base.copyto!( mask isa NoMask && enough_shmem && Operators.use_fd_shmem() + shmem_params = ShmemParams{n_face_levels}() p = fd_shmem_stencil_partition(us, n_face_levels) args = ( strip_space(out, space), @@ -64,7 +69,7 @@ function Base.copyto!( bounds, us, mask, - Val(p.Nvthreads), + shmem_params, ) auto_launch!( copyto_stencil_kernel_shmem!, @@ -153,8 +158,8 @@ function copyto_stencil_kernel_shmem!( bds, us, mask, - ::Val{Nvt}, -) where {Nvt} + shmem_params::ShmemParams, +) @inbounds begin out_fv = Fields.field_values(out) us = DataLayouts.UniversalSize(out_fv) @@ -165,7 +170,7 @@ function copyto_stencil_kernel_shmem!( hidx = (i, j, h) idx = v - 1 + li bc = Operators.reconstruct_placeholder_broadcasted(space, bc′) - bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem + bc_shmem = fd_allocate_shmem(shmem_params, bc) # allocates shmem fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem CUDA.sync_threads()