Skip to content

Commit 71704ec

Browse files
charleskawczynskiCharlie Kawczynski
authored andcommitted
Tune kernels for use with FastCartesianIndices
1 parent f2c91ee commit 71704ec

11 files changed

+149
-506
lines changed

.buildkite/pipeline.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,6 @@ steps:
137137
agents:
138138
slurm_gpus: 1
139139

140-
- label: "Unit: data cuda threadblocks"
141-
key: unit_data_threadblock
142-
command:
143-
- "julia --project=.buildkite -e 'using CUDA; CUDA.versioninfo()'"
144-
- "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cuda_threadblocks.jl"
145-
env:
146-
CLIMACOMMS_DEVICE: "CUDA"
147-
agents:
148-
slurm_gpus: 1
149-
150140
- label: "Unit: data fill"
151141
key: gpu_unit_data_fill
152142
command:

ext/cuda/data_layouts.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ Base.similar(
2727
dims::Dims{N},
2828
) where {T, N, B} = similar(CUDA.CuArray{T, N, B}, dims)
2929

30+
unval(::Val{CI}) where CI = CI
31+
unval(CI) = CI
32+
33+
@inline linear_thread_idx() =
34+
threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
35+
3036
include("data_layouts_fill.jl")
3137
include("data_layouts_copyto.jl")
3238
include("data_layouts_fused_copyto.jl")

ext/cuda/data_layouts_copyto.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()
22

3-
function knl_copyto!(dest, src, us, mask)
4-
I = if mask isa NoMask
5-
universal_index(dest)
6-
else
7-
masked_universal_index(mask)
8-
end
9-
if is_valid_index(dest, I, us)
3+
function knl_copyto!(dest, src, us, mask, cart_inds)
4+
tidx = linear_thread_idx()
5+
if linear_is_valid_index(tidx, us) && tidx length(unval(cart_inds))
6+
I = if mask isa NoMask
7+
unval(cart_inds)[tidx]
8+
else
9+
masked_universal_index(mask, cart_inds)
10+
end
1011
@inbounds dest[I] = src[I]
1112
end
1213
return nothing
1314
end
1415

1516
function knl_copyto_linear!(dest, src, us)
16-
i = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
17+
i = linear_thread_idx()
1718
if linear_is_valid_index(i, us)
1819
@inbounds dest[i] = src[i]
1920
end
@@ -32,13 +33,18 @@ if VERSION ≥ v"1.11.0-beta"
3233
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
3334
us = DataLayouts.UniversalSize(dest)
3435
if Nv > 0 && Nh > 0
35-
args = (dest, bc, us, mask)
36+
cart_inds = if mask isa NoMask
37+
cartesian_indices(us)
38+
else
39+
cartesian_indicies_mask(us, mask)
40+
end
41+
args = (dest, bc, us, mask, cart_inds)
3642
threads = threads_via_occupancy(knl_copyto!, args)
3743
n_max_threads = min(threads, get_N(us))
3844
p = if mask isa NoMask
39-
partition(dest, n_max_threads)
45+
linear_partition(prod(size(dest)), n_max_threads)
4046
else
41-
masked_partition(us, n_max_threads, mask)
47+
masked_partition(mask, n_max_threads, us)
4248
end
4349
auto_launch!(
4450
knl_copyto!,
@@ -72,13 +78,18 @@ else
7278
blocks_s = p.blocks,
7379
)
7480
else
75-
args = (dest, bc, us, mask)
81+
cart_inds = if mask isa NoMask
82+
cartesian_indices(us)
83+
else
84+
cartesian_indicies_mask(us, mask)
85+
end
86+
args = (dest, bc, us, mask, cart_inds)
7687
threads = threads_via_occupancy(knl_copyto!, args)
7788
n_max_threads = min(threads, get_N(us))
7889
p = if mask isa NoMask
79-
partition(dest, n_max_threads)
90+
linear_partition(prod(size(dest)), n_max_threads)
8091
else
81-
masked_partition(us, n_max_threads, mask)
92+
masked_partition(mask, n_max_threads, us)
8293
end
8394
auto_launch!(
8495
knl_copyto!,

ext/cuda/data_layouts_fill.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
1-
function knl_fill!(dest, val, us, mask)
2-
I = if mask isa NoMask
3-
universal_index(dest)
4-
else
5-
masked_universal_index(mask)
6-
end
7-
if is_valid_index(dest, I, us)
1+
function knl_fill!(dest, val, us, mask, cart_inds)
2+
tidx = linear_thread_idx()
3+
if linear_is_valid_index(tidx, us) && tidx length(unval(cart_inds))
4+
I = if mask isa NoMask
5+
unval(cart_inds)[tidx]
6+
else
7+
masked_universal_index(mask, cart_inds)
8+
end
89
@inbounds dest[I] = val
910
end
1011
return nothing
1112
end
1213

1314
function knl_fill_linear!(dest, val, us)
14-
i = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
15+
i = linear_thread_idx()
1516
if linear_is_valid_index(i, us)
1617
@inbounds dest[i] = val
1718
end
1819
return nothing
1920
end
2021

2122
function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
22-
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
23+
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(dest)
2324
us = DataLayouts.UniversalSize(dest)
2425
if Nv > 0 && Nh > 0
2526
if !(VERSION v"1.11.0-beta") &&
@@ -36,13 +37,18 @@ function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
3637
blocks_s = p.blocks,
3738
)
3839
else
39-
args = (dest, bc, us, mask)
40+
cart_inds = if mask isa NoMask
41+
cartesian_indices(us)
42+
else
43+
cartesian_indicies_mask(us, mask)
44+
end
45+
args = (dest, bc, us, mask, cart_inds)
4046
threads = threads_via_occupancy(knl_fill!, args)
4147
n_max_threads = min(threads, get_N(us))
4248
p = if mask isa NoMask
43-
partition(dest, n_max_threads)
49+
linear_partition(prod(size(dest)), n_max_threads)
4450
else
45-
masked_partition(us, n_max_threads, mask)
51+
masked_partition(mask, n_max_threads, us)
4652
end
4753
auto_launch!(
4854
knl_fill!,

ext/cuda/data_layouts_fused_copyto.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,39 @@
11
Base.@propagate_inbounds function rcopyto_at!(
22
pair::Pair{<:AbstractData, <:Any},
3-
I,
3+
cart_inds,
4+
tidx,
45
us,
56
)
67
dest, bc = pair.first, pair.second
7-
if is_valid_index(dest, I, us)
8+
if linear_is_valid_index(tidx, us) && tidx length(unval(cart_inds))
9+
I = unval(cart_inds)[tidx]
810
dest[I] = isascalar(bc) ? bc[] : bc[I]
911
end
1012
return nothing
1113
end
12-
Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, us)
14+
Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, cart_inds, tidx, us)
1315
dest, bc = pair.first, pair.second
14-
if is_valid_index(dest, I, us)
16+
if linear_is_valid_index(tidx, us) && tidx length(unval(cart_inds))
17+
I = unval(cart_inds)[tidx]
1518
bcI = isascalar(bc) ? bc[] : bc[I]
1619
dest[] = bcI
1720
end
1821
return nothing
1922
end
20-
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, us)
21-
rcopyto_at!(first(pairs), I, us)
22-
rcopyto_at!(Base.tail(pairs), I, us)
23+
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, cart_inds, tidx, us)
24+
rcopyto_at!(first(pairs), cart_inds, tidx, us)
25+
rcopyto_at!(Base.tail(pairs), cart_inds, tidx, us)
2326
end
24-
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, us) =
25-
rcopyto_at!(first(pairs), I, us)
26-
@inline rcopyto_at!(pairs::Tuple{}, I, us) = nothing
27+
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, cart_inds, tidx, us) =
28+
rcopyto_at!(first(pairs), cart_inds, tidx, us)
29+
@inline rcopyto_at!(pairs::Tuple{}, cart_inds, tidx, us) = nothing
2730

28-
function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us)
31+
function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us, cart_inds)
2932
@inbounds begin
30-
I = universal_index(dest1)
31-
if is_valid_index(dest1, I, us)
33+
tidx = linear_thread_idx()
34+
if linear_is_valid_index(tidx, us) && tidx length(unval(cart_inds))
3235
(; pairs) = fmbc
33-
rcopyto_at!(pairs, I, us)
36+
rcopyto_at!(pairs, cart_inds, tidx, us)
3437
end
3538
end
3639
return nothing
@@ -138,10 +141,11 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast)
138141
blocks_s = p.blocks,
139142
)
140143
else
141-
args = (fmb, dest1, us)
144+
cart_inds = cartesian_indices(us)
145+
args = (fmb, dest1, us, cart_inds)
142146
threads = threads_via_occupancy(knl_fused_copyto!, args)
143147
n_max_threads = min(threads, get_N(us))
144-
p = partition(dest1, n_max_threads)
148+
p = linear_partition(prod(size(dest1)), n_max_threads)
145149
auto_launch!(
146150
knl_fused_copyto!,
147151
args;

0 commit comments

Comments
 (0)