Skip to content

Commit ad541e0

Browse files
committed
Improve load distribution of solver for lowres gpu simulations
1 parent 46ea344 commit ad541e0

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

ext/cuda/cuda_utils.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,40 @@ function threads_via_occupancy(f!::F!, args) where {F!}
208208
return config.threads
209209
end
210210

211+
"""
    config_via_occupancy(f!::F!, nitems, args) where {F!}

Return a named tuple `(; threads, blocks)` giving an approximately optimal
launch configuration for kernel `f!` called with `args`, when there are
`nitems` total items to process.

If `nitems` is large enough that the configuration suggested by
`CUDA.launch_configuration` saturates the device, that suggestion is used
directly. Otherwise the items are spread across more SMs (smaller blocks,
more of them) to improve occupancy on small problems.
"""
function config_via_occupancy(f!::F!, nitems, args) where {F!}
    # Compile (but do not launch) the kernel so the occupancy API can
    # inspect its resource usage.
    kernel = CUDA.@cuda always_inline = true launch = false f!(args...)
    config = CUDA.launch_configuration(kernel.fun)
    dev = CUDA.device()
    n_sms = CUDA.attribute(dev, CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
    block_dim_limit = CUDA.attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
    if cld(nitems, config.threads) < config.blocks
        # Too few items to fill the suggested number of blocks: the GPU
        # would not saturate, so distribute the work across more SMs.
        per_sm = cld(nitems, n_sms)
        # Respect the maximum block size (usually limited by register
        # pressure); if exceeded, attempt to halve the thread count once.
        if per_sm > block_dim_limit
            per_sm = div(per_sm, 2)
        end
        # per_sm should already be below config.threads here; min() is a
        # safety cap so we never exceed the occupancy-derived limit.
        threads = min(per_sm, config.threads)
    else
        # Saturated case: use the suggested block size, trimmed to nitems.
        threads = min(nitems, config.threads)
    end
    return (; threads, blocks = cld(nitems, threads))
end
244+
211245
"""
212246
thread_index()
213247

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,12 @@ NVTX.@annotate function multiple_field_solve!(
3838
args = (device, caches, xs, As, bs, x1, us, mask, cart_inds, Val(Nnames))
3939

4040
nitems = Ni * Nj * Nh * Nnames
41-
threads = threads_via_occupancy(multiple_field_solve_kernel!, args)
42-
n_max_threads = min(threads, nitems)
43-
p = linear_partition(nitems, n_max_threads)
44-
41+
(; threads, blocks) = config_via_occupancy(multiple_field_solve_kernel!, nitems, args)
4542
auto_launch!(
4643
multiple_field_solve_kernel!,
4744
args;
48-
threads_s = p.threads,
49-
blocks_s = p.blocks,
45+
threads_s = threads,
46+
blocks_s = blocks,
5047
always_inline = true,
5148
)
5249
call_post_op_callback() && post_op_callback(x, dev, cache, x, A, b, x1)

ext/cuda/matrix_fields_single_field_solve.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
1919
mask = Spaces.get_mask(axes(x))
2020
cart_inds = cartesian_indices_columnwise(us)
2121
args = (device, cache, x, A, b, us, mask, cart_inds)
22-
threads = threads_via_occupancy(single_field_solve_kernel!, args)
2322
nitems = Ni * Nj * Nh
24-
n_max_threads = min(threads, nitems)
25-
p = linear_partition(nitems, n_max_threads)
23+
(; threads, blocks) = config_via_occupancy(single_field_solve_kernel!, nitems, args)
2624
auto_launch!(
2725
single_field_solve_kernel!,
2826
args;
29-
threads_s = p.threads,
30-
blocks_s = p.blocks,
27+
threads_s = threads,
28+
blocks_s = blocks,
3129
)
3230
call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
3331
end

0 commit comments

Comments
 (0)