Skip to content

Commit 4d7246e

Browse files
charleskawczynskiCharlie Kawczynski
andauthored
Add performance tuning tests (#2291)
Co-authored-by: Charlie Kawczynski <[email protected]>
1 parent e30d2b7 commit 4d7246e

File tree

4 files changed

+229
-4
lines changed

4 files changed

+229
-4
lines changed

.buildkite/pipeline.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ steps:
137137
agents:
138138
slurm_gpus: 1
139139

140+
- label: "Unit: data cuda threadblocks"
141+
key: unit_data_threadblock
142+
command:
143+
- "julia --project=.buildkite -e 'using CUDA; CUDA.versioninfo()'"
144+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cuda_threadblocks.jl"
145+
env:
146+
CLIMACOMMS_DEVICE: "CUDA"
147+
agents:
148+
slurm_gpus: 1
149+
140150
- label: "Unit: data fill"
141151
key: gpu_unit_data_fill
142152
command:

ext/cuda/data_layouts_threadblock.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,21 @@ end
191191
@inline is_valid_index(::DataLayouts.DataF, I::CI5, us::UniversalSize) = true
192192

193193
##### Masked
194-
@inline function masked_partition(
194+
@inline masked_partition(
195195
us::DataLayouts.UniversalSize,
196196
n_max_threads::Integer,
197197
mask::IJHMask,
198+
) = masked_partition(us, n_max_threads, typeof(mask), mask.N[1])
199+
200+
@inline function masked_partition(
201+
us::DataLayouts.UniversalSize,
202+
n_max_threads::Integer,
203+
::Type{<:IJHMask},
204+
n_active_columns::Integer,
198205
)
199206
(Ni, _, _, Nv, Nh) = DataLayouts.universal_size(us)
200207
Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv)
201208
Nv_blocks = cld(Nv, Nv_thread)
202-
n_active_columns = mask.N[1]
203209
@assert Nv_thread n_max_threads "threads,n_max_threads=($Nv_thread,$n_max_threads)"
204210
return (; threads = (Nv_thread,), blocks = (n_active_columns, Nv_blocks))
205211
end
@@ -315,7 +321,7 @@ end
315321
) = Operators.is_valid_index(space, ij, slabidx)
316322

317323
##### shmem fd kernel partition
318-
@inline function fd_stencil_partition(
324+
@inline function fd_shmem_stencil_partition(
319325
us::DataLayouts.UniversalSize,
320326
n_face_levels::Integer,
321327
n_max_threads::Integer = 256;

ext/cuda/operators_finite_difference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function Base.copyto!(
5656
mask isa NoMask &&
5757
enough_shmem &&
5858
Operators.use_fd_shmem()
59-
p = fd_stencil_partition(us, n_face_levels)
59+
p = fd_shmem_stencil_partition(us, n_face_levels)
6060
args = (
6161
strip_space(out, space),
6262
strip_space(bc, space),
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
#=
2+
julia --project=.buildkite
3+
using Revise; include(joinpath("test", "DataLayouts", "unit_cuda_threadblocks.jl"))
4+
=#
5+
ENV["CLIMACOMMS_DEVICE"] = "CUDA"
6+
using Test
7+
using ClimaCore.DataLayouts
8+
using ClimaCore
9+
import ClimaComms
10+
ClimaComms.@import_required_backends
11+
ext = Base.get_extension(ClimaCore, :ClimaCoreCUDAExt)
12+
@assert !isnothing(ext) # cuda must be loaded to test this extension
13+
14+
function get_inputs()
15+
device = ClimaComms.device()
16+
ArrayType = ClimaComms.array_type(device)
17+
FT = Float64
18+
S = FT
19+
args = (ArrayType{FT}, zeros)
20+
return (; S, args)
21+
end
22+
23+
pt(d) = ext.partition(d, DataLayouts.get_N(DataLayouts.UniversalSize(d)))
24+
pt_stencil(d) = ext.fd_shmem_stencil_partition(
25+
DataLayouts.UniversalSize(d),
26+
DataLayouts.get_Nv(d),
27+
)
28+
pt_columnwise(d) =
29+
ext.columnwise_partition(DataLayouts.UniversalSize(d), DataLayouts.get_N(d))
30+
pt_mfs(d; Nnames) = ext.multiple_field_solve_partition(
31+
DataLayouts.UniversalSize(d),
32+
DataLayouts.get_N(d);
33+
Nnames,
34+
)
35+
pt_sem(d) =
36+
ext.spectral_partition(DataLayouts.UniversalSize(d), DataLayouts.get_N(d))
37+
get_Nh(h_elem) = h_elem^2 * 6
38+
39+
function pt_masked(d; frac)
40+
us = DataLayouts.UniversalSize(d)
41+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
42+
n_active_columns = Int(round(prod((Ni, Nj, Nh)) * frac; digits = 0))
43+
ext.masked_partition(
44+
us,
45+
DataLayouts.get_N(us),
46+
DataLayouts.IJHMask,
47+
n_active_columns,
48+
)
49+
end
50+
51+
#! format: off
52+
53+
@testset "linear_partition" begin
54+
# Fully optimized (but can be 2x slower due to integer division in CartesianIndices).
55+
# If https://github.com/maleadt/StaticCartesian.jl/issues/1 ever works, we should
56+
# basically always use that instead.
57+
end
58+
@testset "DataF partition" begin
59+
(; S, args) = get_inputs()
60+
@test pt(DataF{S}(args...)) == (; threads = 1, blocks = 1)
61+
end
62+
@testset "IJFH/IJHF partition" begin
63+
(; S, args) = get_inputs()
64+
for DL in (IJFH, IJHF)
65+
@test pt(DL{S}(args...; Nij = 1, Nh = 1)) == (; threads = (1, 1, 1), blocks = (1,))
66+
@test pt(DL{S}(args...; Nij = 1, Nh = get_Nh(1))) == (; threads = (1, 1, 6), blocks = (1,))
67+
@test pt(DL{S}(args...; Nij = 4, Nh = get_Nh(30))) == (; threads = (4, 4, 64), blocks = (85,))
68+
@test pt(DL{S}(args...; Nij = 4, Nh = get_Nh(100))) == (; threads = (4, 4, 64), blocks = (938,))
69+
@test pt(DL{S}(args...; Nij = 1, Nh = get_Nh(1))) == (; threads = (1, 1, 6), blocks = (1,))
70+
@test pt(DL{S}(args...; Nij = 1, Nh = get_Nh(30))) == (; threads = (1, 1, 64), blocks = (85,))
71+
@test pt(DL{S}(args...; Nij = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (938,))
72+
end
73+
end
74+
@testset "IFH/IHF partition" begin
75+
(; S, args) = get_inputs()
76+
for DL in (IFH, IHF)
77+
@test pt(DL{S}(args...; Ni = 1, Nh = 1)) == (; threads = (1, 1), blocks = (1,))
78+
@test pt(DL{S}(args...; Ni = 1, Nh = get_Nh(1))) == (; threads = (1, 6), blocks = (1,))
79+
@test pt(DL{S}(args...; Ni = 4, Nh = get_Nh(30))) == (; threads = (4, 5400), blocks = (1,))
80+
@test pt(DL{S}(args...; Ni = 4, Nh = get_Nh(100))) == (; threads = (4, 60000), blocks = (1,)) # TODO: needs fixed (too many threads per block)
81+
@test pt(DL{S}(args...; Ni = 1, Nh = get_Nh(30))) == (; threads = (1, 5400), blocks = (1,))
82+
@test pt(DL{S}(args...; Ni = 1, Nh = get_Nh(100))) == (; threads = (1, 60000), blocks = (1,)) # TODO: needs fixed (too many threads per block)
83+
end
84+
end
85+
@testset "IJF partition" begin
86+
(; S, args) = get_inputs()
87+
@test pt(IJF{S}(args...; Nij = 1)) == (; threads = (1, 1), blocks = (1,))
88+
@test pt(IJF{S}(args...; Nij = 4)) == (; threads = (4, 4), blocks = (1,))
89+
end
90+
@testset "IF partition" begin
91+
(; S, args) = get_inputs()
92+
@test pt(IF{S}(args...; Ni = 1)) == (; threads = (1,), blocks = (1,))
93+
@test pt(IF{S}(args...; Ni = 4)) == (; threads = (4,), blocks = (1,))
94+
end
95+
@testset "VF partition" begin
96+
(; S, args) = get_inputs()
97+
@test pt(VF{S}(args...; Nv = 1)) == (; threads = (1, ), blocks = (1, ))
98+
@test pt(VF{S}(args...; Nv = 10)) == (; threads = (1, ), blocks = (10, ))
99+
@test pt(VF{S}(args...; Nv = 64)) == (; threads = (1, ), blocks = (64, ))
100+
@test pt(VF{S}(args...; Nv = 1000)) == (; threads = (1, ), blocks = (1000, ))
101+
end
102+
@testset "VIJFH/VIJHF partition" begin
103+
(; S, args) = get_inputs()
104+
for DL in (VIJFH, VIJHF)
105+
@test pt(DL{S}(args...; Nv = 1, Nij = 1, Nh = 1)) == (; threads = (1, 1, 1), blocks = (1, 1))
106+
@test pt(DL{S}(args...; Nv = 1, Nij = 1, Nh = get_Nh(1))) == (; threads = (1, 1, 1), blocks = (6, 1))
107+
@test pt(DL{S}(args...; Nv = 64, Nij = 4, Nh = get_Nh(30))) == (; threads = (64, 4, 4), blocks = (5400, 1))
108+
@test pt(DL{S}(args...; Nv = 64, Nij = 4, Nh = get_Nh(100))) == (; threads = (64, 4, 4), blocks = (60000, 1))
109+
@test pt(DL{S}(args...; Nv = 64, Nij = 1, Nh = get_Nh(100))) == (; threads = (64, 1, 1), blocks = (60000, 1)) # need more threads per block?
110+
@test pt(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100))) == (; threads = (10, 1, 1), blocks = (60000, 1)) # need more threads per block
111+
@test pt(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(30))) == (; threads = (10, 1, 1), blocks = (5400, 1)) # need more threads per block
112+
@test pt(DL{S}(args...; Nv = 1000, Nij = 1, Nh = get_Nh(30))) == (; threads = (1000, 1, 1), blocks = (5400, 1))
113+
@test pt(DL{S}(args...; Nv = 2000, Nij = 1, Nh = get_Nh(30))) == (; threads = (2000, 1, 1), blocks = (5400, 1)) # TODO: fix this? maximum_allowable_threads()[1] == 1024
114+
end
115+
end
116+
@testset "VIFH/VIHF partition" begin
117+
(; S, args) = get_inputs()
118+
for DL in (VIFH, VIHF)
119+
@test pt(DL{S}(args...; Nv = 1, Ni = 1, Nh = 1)) == (; threads = (1, 1), blocks = (1, 1))
120+
@test pt(DL{S}(args...; Nv = 1, Ni = 1, Nh = get_Nh(1))) == (; threads = (1, 1), blocks = (6, 1))
121+
@test pt(DL{S}(args...; Nv = 64, Ni = 4, Nh = get_Nh(30))) == (; threads = (64, 4), blocks = (5400, 1))
122+
@test pt(DL{S}(args...; Nv = 64, Ni = 4, Nh = get_Nh(100))) == (; threads = (64, 4), blocks = (60000, 1))
123+
@test pt(DL{S}(args...; Nv = 64, Ni = 1, Nh = get_Nh(100))) == (; threads = (64, 1), blocks = (60000, 1)) # need more threads per block?
124+
@test pt(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100))) == (; threads = (10, 1), blocks = (60000, 1)) # need more threads per block
125+
end
126+
end
127+
128+
@testset "fd_shmem_stencil_partition" begin
129+
(; S, args) = get_inputs()
130+
for DL in (VIFH, VIHF)
131+
@test pt_stencil(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100))) == (; threads = (10,), blocks = (60000, 1, 1), Nvthreads = 10)
132+
@test pt_stencil(DL{S}(args...; Nv = 10, Ni = 4, Nh = get_Nh(100))) == (; threads = (10,), blocks = (60000, 1, 4), Nvthreads = 10)
133+
@test pt_stencil(DL{S}(args...; Nv = 100, Ni = 4, Nh = get_Nh(100))) == (; threads = (100,), blocks = (60000, 1, 4), Nvthreads = 100)
134+
end
135+
for DL in (VIJFH, VIJHF)
136+
@test pt_stencil(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100))) == (; threads = (10,), blocks = (60000, 1, 1), Nvthreads = 10)
137+
@test pt_stencil(DL{S}(args...; Nv = 10, Nij = 4, Nh = get_Nh(100))) == (; threads = (10,), blocks = (60000, 1, 16), Nvthreads = 10)
138+
@test pt_stencil(DL{S}(args...; Nv = 100, Nij = 4, Nh = get_Nh(100))) == (; threads = (100,), blocks = (60000, 1, 16), Nvthreads = 100)
139+
end
140+
@test pt_stencil(VF{S}(args...; Nv = 10)) == (; threads = (10,), blocks = (1, 1, 1), Nvthreads = 10)
141+
@test pt_stencil(VF{S}(args...; Nv = 1000)) == (; threads = (1000,), blocks = (1, 1, 1), Nvthreads = 1000)
142+
end
143+
144+
@testset "spectral_partition" begin
145+
(; S, args) = get_inputs()
146+
for DL in (VIFH, VIHF)
147+
@test pt_sem(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (60000, 1), Nvthreads = 64)
148+
@test pt_sem(DL{S}(args...; Nv = 10, Ni = 4, Nh = get_Nh(100))) == (; threads = (4, 1, 64), blocks = (60000, 1), Nvthreads = 64)
149+
@test pt_sem(DL{S}(args...; Nv = 100, Ni = 4, Nh = get_Nh(100))) == (; threads = (4, 1, 64), blocks = (60000, 2), Nvthreads = 64)
150+
end
151+
for DL in (VIJFH, VIJHF)
152+
@test pt_sem(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (60000, 1), Nvthreads = 64)
153+
@test pt_sem(DL{S}(args...; Nv = 10, Nij = 4, Nh = get_Nh(100))) == (; threads = (4, 4, 64), blocks = (60000, 1), Nvthreads = 64)
154+
@test pt_sem(DL{S}(args...; Nv = 100, Nij = 4, Nh = get_Nh(100))) == (; threads = (4, 4, 64), blocks = (60000, 2), Nvthreads = 64)
155+
end
156+
for DL in (IJFH, IJHF)
157+
@test pt_sem(DL{S}(args...; Nij = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (60000, 1), Nvthreads = 64) # can/should we reduce # of blocks?
158+
@test pt_sem(DL{S}(args...; Nij = 4, Nh = get_Nh(100))) == (; threads = (4, 4, 64), blocks = (60000, 1), Nvthreads = 64) # can/should we reduce # of blocks?
159+
end
160+
end
161+
162+
@testset "columnwise_partition" begin
163+
(; S, args) = get_inputs()
164+
for DL in (IFH, IHF)
165+
@test pt_columnwise(DL{S}(args...; Ni = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (938,))
166+
@test pt_columnwise(DL{S}(args...; Ni = 4, Nh = get_Nh(100))) == (; threads = (4, 1, 64), blocks = (938,))
167+
@test pt_columnwise(DL{S}(args...; Ni = 4, Nh = get_Nh(100))) == (; threads = (4, 1, 64), blocks = (938,))
168+
end
169+
for DL in (IJFH, IJHF)
170+
@test pt_columnwise(DL{S}(args...; Nij = 1, Nh = get_Nh(100))) == (; threads = (1, 1, 64), blocks = (938,)) # more threads per block?
171+
@test pt_columnwise(DL{S}(args...; Nij = 4, Nh = get_Nh(100))) == (; threads = (4, 4, 64), blocks = (938,)) # more threads per block?
172+
end
173+
end
174+
175+
@testset "multiple_field_solve_partition" begin
176+
(; S, args) = get_inputs()
177+
for DL in (IFH, IHF)
178+
@test pt_mfs(DL{S}(args...; Ni = 1, Nh = get_Nh(100)); Nnames = 1) == (; threads = (1, 1, 1), blocks = (60000,))
179+
@test pt_mfs(DL{S}(args...; Ni = 4, Nh = get_Nh(100)); Nnames = 2) == (; threads = (4, 1, 2), blocks = (60000,)) # more threads per block?
180+
end
181+
for DL in (IJFH, IJHF)
182+
@test pt_mfs(DL{S}(args...; Nij = 1, Nh = get_Nh(100)); Nnames = 1) == (; threads = (1, 1, 1), blocks = (60000,)) # more threads per block?
183+
@test pt_mfs(DL{S}(args...; Nij = 4, Nh = get_Nh(100)); Nnames = 2) == (; threads = (4, 4, 2), blocks = (60000,)) # more threads per block?
184+
end
185+
end
186+
187+
@testset "masked_partition" begin
188+
(; S, args) = get_inputs()
189+
for DL in (VIFH, VIHF)
190+
@test pt_masked(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100)); frac = 0.5) == (; threads = (10,), blocks = (30000, 1)) # need more threads per block
191+
@test pt_masked(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100)); frac = 0.1) == (; threads = (10,), blocks = (6000, 1)) # need more threads per block
192+
@test pt_masked(DL{S}(args...; Nv = 10, Ni = 1, Nh = get_Nh(100)); frac = 0.8) == (; threads = (10,), blocks = (48000, 1)) # need more threads per block
193+
194+
@test pt_masked(DL{S}(args...; Nv = 100, Ni = 1, Nh = get_Nh(100)); frac = 0.5) == (; threads = (100,), blocks = (30000, 1)) # need more threads per block
195+
@test pt_masked(DL{S}(args...; Nv = 100, Ni = 1, Nh = get_Nh(100)); frac = 0.1) == (; threads = (100,), blocks = (6000, 1)) # need more threads per block
196+
@test pt_masked(DL{S}(args...; Nv = 100, Ni = 1, Nh = get_Nh(100)); frac = 0.8) == (; threads = (100,), blocks = (48000, 1)) # need more threads per block
197+
end
198+
for DL in (VIJFH, VIJHF)
199+
@test pt_masked(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100)); frac = 0.5) == (; threads = (10,), blocks = (30000, 1)) # need more threads per block
200+
@test pt_masked(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100)); frac = 0.1) == (; threads = (10,), blocks = (6000, 1)) # need more threads per block
201+
@test pt_masked(DL{S}(args...; Nv = 10, Nij = 1, Nh = get_Nh(100)); frac = 0.8) == (; threads = (10,), blocks = (48000, 1)) # need more threads per block
202+
203+
@test pt_masked(DL{S}(args...; Nv = 100, Nij = 1, Nh = get_Nh(100)); frac = 0.5) == (; threads = (100,), blocks = (30000, 1)) # need more threads per block
204+
@test pt_masked(DL{S}(args...; Nv = 100, Nij = 1, Nh = get_Nh(100)); frac = 0.1) == (; threads = (100,), blocks = (6000, 1)) # need more threads per block
205+
@test pt_masked(DL{S}(args...; Nv = 100, Nij = 1, Nh = get_Nh(100)); frac = 0.8) == (; threads = (100,), blocks = (48000, 1)) # need more threads per block
206+
end
207+
end
208+
209+
#! format: on

0 commit comments

Comments
 (0)