Skip to content

Commit 30aeb41

Browse files
Generalize shmem size
1 parent bb9aa81 commit 30aeb41

File tree

3 files changed

+41
-33
lines changed

3 files changed

+41
-33
lines changed

ext/cuda/operators_fd_shmem.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import ClimaCore.RecursiveApply: ⊟, ⊞
66

77
Base.@propagate_inbounds function fd_operator_shmem(
88
space,
9-
::Val{Nvt},
9+
::Val{params},
1010
op::Operators.DivergenceF2C,
1111
args...,
12-
) where {Nvt}
12+
) where {params}
1313
# allocate temp output
1414
RT = return_eltype(op, args...)
15-
Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,))
16-
lJu³ = CUDA.CuStaticSharedArray(RT, (1,))
17-
rJu³ = CUDA.CuStaticSharedArray(RT, (1,))
15+
Ju³ = CUDA.CuStaticSharedArray(RT, shmem_size(Val(params)))
16+
lJu³ = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params)))
17+
rJu³ = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params)))
1818
return (Ju³, lJu³, rJu³)
1919
end
2020

@@ -109,15 +109,15 @@ end
109109

110110
Base.@propagate_inbounds function fd_operator_shmem(
111111
space,
112-
::Val{Nvt},
112+
::Val{params},
113113
op::Operators.GradientC2F,
114114
args...,
115-
) where {Nvt}
115+
) where {params}
116116
# allocate temp output
117117
RT = return_eltype(op, args...)
118-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
119-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
120-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
118+
u = CUDA.CuStaticSharedArray(RT, shmem_size(Val(params))) # cell centers
119+
lb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params))) # left boundary
120+
rb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params))) # right boundary
121121
return (u, lb, rb)
122122
end
123123

@@ -202,15 +202,15 @@ end
202202

203203
Base.@propagate_inbounds function fd_operator_shmem(
204204
space,
205-
::Val{Nvt},
205+
::Val{params},
206206
op::Operators.InterpolateC2F,
207207
args...,
208-
) where {Nvt}
208+
) where {params}
209209
# allocate temp output
210210
RT = return_eltype(op, args...)
211-
u = CUDA.CuStaticSharedArray(RT, (Nvt,)) # cell centers
212-
lb = CUDA.CuStaticSharedArray(RT, (1,)) # left boundary
213-
rb = CUDA.CuStaticSharedArray(RT, (1,)) # right boundary
211+
u = CUDA.CuStaticSharedArray(RT, shmem_size(Val(params))) # cell centers
212+
lb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params))) # left boundary
213+
rb = CUDA.CuStaticSharedArray(RT, boundary_shmem_size(Val(params))) # right boundary
214214
return (u, lb, rb)
215215
end
216216

ext/cuda/operators_fd_shmem_common.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,23 @@ Base.@propagate_inbounds function getidx(
209209
end
210210

211211
"""
212-
fd_allocate_shmem(Val(Nvt), b)
212+
fd_allocate_shmem(Val(params), b)
213213
214214
Create a new broadcasted object with the necessary shared memory allocated,
215-
using `Nvt` nodal points per block.
215+
sized according to the shared-memory parameters `params`.
216216
"""
217-
@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt}
217+
@inline function fd_allocate_shmem(::Val{params}, obj) where {params}
218218
obj
219219
end
220220
@inline function fd_allocate_shmem(
221-
::Val{Nvt},
221+
::Val{params},
222222
bc::Broadcasted{Style},
223-
) where {Nvt, Style}
224-
Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes)
223+
) where {params, Style}
224+
Broadcasted{Style}(
225+
bc.f,
226+
_fd_allocate_shmem(Val(params), bc.args...),
227+
bc.axes,
228+
)
225229
end
226230

227231
######### MatrixFields
@@ -236,24 +240,27 @@ end
236240
#########
237241

238242
@inline function fd_allocate_shmem(
239-
::Val{Nvt},
243+
::Val{params},
240244
sbc::StencilBroadcasted{Style},
241-
) where {Nvt, Style}
242-
args = _fd_allocate_shmem(Val(Nvt), sbc.args...)
245+
) where {params, Style}
246+
args = _fd_allocate_shmem(Val(params), sbc.args...)
243247
work = if Operators.fd_shmem_is_supported(sbc)
244-
fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...)
248+
fd_operator_shmem(sbc.axes, Val(params), sbc.op, args...)
245249
else
246250
nothing
247251
end
248252
StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work)
249253
end
250254

251-
@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = ()
252-
@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt} = (
253-
fd_allocate_shmem(Val(Nvt), arg),
254-
_fd_allocate_shmem(Val(Nvt), xargs...)...,
255+
# Base case of the recursion over arguments: with no arguments left, return an empty tuple.
@inline function _fd_allocate_shmem(::Val{params}) where {params}
    return ()
end
256+
@inline _fd_allocate_shmem(::Val{params}, arg, xargs...) where {params} = (
257+
fd_allocate_shmem(Val(params), arg),
258+
_fd_allocate_shmem(Val(params), xargs...)...,
255259
)
256260

261+
"""
    shmem_size(::Val{params})

Return the dimensions, as a 1-tuple, of the main shared-memory array allocated by
`fd_operator_shmem`.

The length is read from `params.Nvt` when `params` has that property and from
`params.Nv` otherwise, so that parameters built as `(; Nv = ...)` (as done in the
shared-memory `copyto!` launch path) are accepted as well as `(; Nvt = ...)`.
"""
shmem_size(::Val{params}) where {params} =
    (hasproperty(params, :Nvt) ? params.Nvt : params.Nv,)
262+
"""
    boundary_shmem_size(::Val{params})

Return `(1,)`, the dimensions used for each boundary shared-memory array,
independent of the contents of `params`.
"""
function boundary_shmem_size(::Val{params}) where {params}
    return (1,)
end
263+
257264
"""
258265
fd_shmem_needed_per_column(::Base.Broadcast.Broadcasted)
259266
fd_shmem_needed_per_column(::StencilBroadcasted)

ext/cuda/operators_finite_difference.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ function Base.copyto!(
5656
mask isa NoMask &&
5757
enough_shmem &&
5858
Operators.use_fd_shmem()
59+
shmem_params = (; Nv = n_face_levels)
5960
p = fd_shmem_stencil_partition(us, n_face_levels)
6061
args = (
6162
strip_space(out, space),
@@ -64,7 +65,7 @@ function Base.copyto!(
6465
bounds,
6566
us,
6667
mask,
67-
Val(p.Nvthreads),
68+
Val(shmem_params),
6869
)
6970
auto_launch!(
7071
copyto_stencil_kernel_shmem!,
@@ -153,8 +154,8 @@ function copyto_stencil_kernel_shmem!(
153154
bds,
154155
us,
155156
mask,
156-
::Val{Nvt},
157-
) where {Nvt}
157+
::Val{shmem_params},
158+
) where {shmem_params}
158159
@inbounds begin
159160
out_fv = Fields.field_values(out)
160161
us = DataLayouts.UniversalSize(out_fv)
@@ -165,7 +166,7 @@ function copyto_stencil_kernel_shmem!(
165166
hidx = (i, j, h)
166167
idx = v - 1 + li
167168
bc = Operators.reconstruct_placeholder_broadcasted(space, bc′)
168-
bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem
169+
bc_shmem = fd_allocate_shmem(Val(shmem_params), bc) # allocates shmem
169170

170171
fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem
171172
CUDA.sync_threads()

0 commit comments

Comments
 (0)