Skip to content

Commit 9267daa

Browse files
committed
Add GPU support for extruded 1D spaces
1 parent 53acb3f commit 9267daa

27 files changed

+821
-566
lines changed

.buildkite/pipeline.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,6 @@ steps:
358358
- label: "Unit: distributed remapping with CUDA (1 process)"
359359
key: distributed_remapping_gpu_1proc
360360
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Remapping/distributed_remapping.jl"
361-
env:
362-
CLIMACOMMS_DEVICE: "CUDA"
363361
env:
364362
CLIMACOMMS_DEVICE: "CUDA"
365363
agents:
@@ -609,6 +607,16 @@ steps:
609607
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/unit_2d.jl"
610608
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/convergence_2d.jl"
611609

610+
- label: "Unit: hyb ops 2d CUDA"
611+
key: unit_hyb_ops_2d_cuda
612+
command:
613+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/unit_2d.jl"
614+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/hybrid/convergence_2d.jl"
615+
env:
616+
CLIMACOMMS_DEVICE: "CUDA"
617+
agents:
618+
slurm_gpus: 1
619+
612620
- label: "Unit: hyb ops 3d"
613621
key: unit_hyb_ops_3d
614622
command:

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,7 @@ end
2424
function mapreduce_cuda(
2525
f,
2626
op,
27-
data::Union{
28-
DataLayouts.VF,
29-
DataLayouts.IJFH,
30-
DataLayouts.IJHF,
31-
DataLayouts.VIJFH,
32-
DataLayouts.VIJHF,
33-
};
27+
data::DataLayouts.AbstractData;
3428
weighted_jacobian = OnesArray(parent(data)),
3529
opargs...,
3630
)
@@ -132,9 +126,9 @@ function mapreduce_cuda_kernel!(
132126
gidx = _get_gidx(tidx, bidx, effective_blksize)
133127
reduction = CUDA.CuStaticSharedArray(T, shmemsize)
134128
reduction[tidx] = 0
135-
(Nij, _, _, Nv, Nh) = DataLayouts.universal_size(us)
129+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
136130
Nf = 1 # a view into `fidx` always gives a size of Nf = 1
137-
nitems = Nv * Nij * Nij * Nf * Nh
131+
nitems = Nv * Ni * Nj * Nf * Nh
138132

139133
# load shmem
140134
if gidx ≤ nitems

ext/cuda/data_layouts_threadblock.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,15 @@ end
213213
us::DataLayouts.UniversalSize,
214214
n_max_threads::Integer,
215215
)
216-
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
216+
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
217217
Nh_thread = min(
218-
Int(fld(n_max_threads, Nij * Nij)),
218+
Int(fld(n_max_threads, Ni * Nj)),
219219
maximum_allowable_threads()[3],
220220
Nh,
221221
)
222222
Nh_blocks = cld(Nh, Nh_thread)
223-
@assert prod((Nij, Nij, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)"
224-
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
223+
@assert prod((Ni, Nj, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nh_thread))),$n_max_threads)"
224+
return (; threads = (Ni, Nj, Nh_thread), blocks = (Nh_blocks,))
225225
end
226226
@inline function columnwise_universal_index(us::UniversalSize)
227227
(i, j, th) = CUDA.threadIdx()
@@ -241,9 +241,9 @@ end
241241
n_max_threads::Integer;
242242
Nnames,
243243
)
244-
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
245-
@assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
246-
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
244+
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
245+
@assert prod((Ni, Nj, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nnames))),$n_max_threads)"
246+
return (; threads = (Ni, Nj, Nnames), blocks = (Nh,))
247247
end
248248
@inline function multiple_field_solve_universal_index(us::UniversalSize)
249249
(i, j, iname) = CUDA.threadIdx()
@@ -258,12 +258,12 @@ end
258258
us::DataLayouts.UniversalSize,
259259
n_max_threads::Integer = 256;
260260
)
261-
(Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
262-
Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
261+
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
262+
Nvthreads = min(fld(n_max_threads, Ni * Nj), maximum_allowable_threads()[3])
263263
Nvblocks = cld(Nv, Nvthreads)
264-
@assert prod((Nq, Nq, Nvthreads)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
265-
@assert Nq * Nq ≤ n_max_threads
266-
return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
264+
@assert prod((Ni, Nj, Nvthreads)) ≤ n_max_threads "threads,n_max_threads=($(prod((Ni, Nj, Nvthreads))),$n_max_threads)"
265+
@assert Ni * Nj ≤ n_max_threads
266+
return (; threads = (Ni, Nj, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
267267
end
268268
@inline function spectral_universal_index(space::Spaces.AbstractSpace)
269269
i = threadIdx().x

0 commit comments

Comments
 (0)