Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ Base.similar(
dims::Dims{N},
) where {T, N, B} = similar(CUDA.CuArray{T, N, B}, dims)

"""
    unval(x)

Return the value stored in the type parameter of a `Val` (`unval(Val(v)) == v`);
any other argument is returned unchanged.
"""
function unval end

function unval(::Val{V}) where {V}
    return V
end

function unval(x)
    return x
end

"""
    linear_thread_idx()

Return the calling thread's linear index along the `x` dimension, computed as
`threadIdx().x` plus the offset `(blockIdx().x - 1) * blockDim().x` contributed by
the preceding blocks.
"""
@inline function linear_thread_idx()
    block_offset = (blockIdx().x - Int32(1)) * blockDim().x
    return threadIdx().x + block_offset
end

include("data_layouts_fill.jl")
include("data_layouts_copyto.jl")
include("data_layouts_fused_copyto.jl")
Expand Down
39 changes: 25 additions & 14 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()

# NOTE(review): this span looks like a rendered diff hunk rather than one definition:
# two revisions of `knl_copyto!` are interleaved and the text does not parse as a
# single function. The 4-argument header and the lines up to
# `if is_valid_index(dest, I, us)` appear to be the superseded revision; the
# 5-argument header (with `cart_inds`) starts the current one — confirm against the
# repository before editing.
function knl_copyto!(dest, src, us, mask)
I = if mask isa NoMask
universal_index(dest)
else
masked_universal_index(mask)
end
if is_valid_index(dest, I, us)
# Each thread copies at most one element `src[I]` into `dest[I]` and returns `nothing`.
# `unval` is applied to `cart_inds` before it is measured or indexed.
function knl_copyto!(dest, src, us, mask, cart_inds)
# One linear index per thread.
tidx = linear_thread_idx()
# Threads whose index fails the `us` check or exceeds the index table do nothing.
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
I = if mask isa NoMask
# Unmasked: read this thread's index straight from the table.
unval(cart_inds)[tidx]
else
# Masked: the index is produced by the helper from `mask` and `cart_inds`.
masked_universal_index(mask, cart_inds)
end
# Bounds checks are elided for this access (`@inbounds`).
@inbounds dest[I] = src[I]
end
return nothing
end

function knl_copyto_linear!(dest, src, us)
i = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
i = linear_thread_idx()
if linear_is_valid_index(i, us)
@inbounds dest[i] = src[i]
end
Expand All @@ -32,13 +33,18 @@ if VERSION ≥ v"1.11.0-beta"
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
args = (dest, bc, us, mask)
cart_inds = if mask isa NoMask
cartesian_indices(us)
else
cartesian_indicies_mask(us, mask)
end
args = (dest, bc, us, mask, cart_inds)
threads = threads_via_occupancy(knl_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = if mask isa NoMask
partition(dest, n_max_threads)
linear_partition(prod(size(dest)), n_max_threads)
else
masked_partition(us, n_max_threads, mask)
masked_partition(mask, n_max_threads, us)
end
auto_launch!(
knl_copyto!,
Expand Down Expand Up @@ -72,13 +78,18 @@ else
blocks_s = p.blocks,
)
else
args = (dest, bc, us, mask)
cart_inds = if mask isa NoMask
cartesian_indices(us)
else
cartesian_indicies_mask(us, mask)
end
args = (dest, bc, us, mask, cart_inds)
threads = threads_via_occupancy(knl_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = if mask isa NoMask
partition(dest, n_max_threads)
linear_partition(prod(size(dest)), n_max_threads)
else
masked_partition(us, n_max_threads, mask)
masked_partition(mask, n_max_threads, us)
end
auto_launch!(
knl_copyto!,
Expand Down
30 changes: 18 additions & 12 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
# NOTE(review): this span looks like a rendered diff hunk rather than one definition:
# two revisions of `knl_fill!` are interleaved and the text does not parse as a
# single function. The 4-argument header and the lines up to
# `if is_valid_index(dest, I, us)` appear to be the superseded revision; the
# 5-argument header (with `cart_inds`) starts the current one — confirm against the
# repository before editing.
function knl_fill!(dest, val, us, mask)
I = if mask isa NoMask
universal_index(dest)
else
masked_universal_index(mask)
end
if is_valid_index(dest, I, us)
# Each thread stores `val` into at most one element `dest[I]` and returns `nothing`.
# `unval` is applied to `cart_inds` before it is measured or indexed.
function knl_fill!(dest, val, us, mask, cart_inds)
# One linear index per thread.
tidx = linear_thread_idx()
# Threads whose index fails the `us` check or exceeds the index table do nothing.
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
I = if mask isa NoMask
# Unmasked: read this thread's index straight from the table.
unval(cart_inds)[tidx]
else
# Masked: the index is produced by the helper from `mask` and `cart_inds`.
masked_universal_index(mask, cart_inds)
end
# Bounds checks are elided for this access (`@inbounds`).
@inbounds dest[I] = val
end
return nothing
end

"""
    knl_fill_linear!(dest, val, us)

Store `val` into `dest[i]`, where `i` is the calling thread's linear index as
returned by `linear_thread_idx()`. Threads for which
`linear_is_valid_index(i, us)` is false do nothing. Always returns `nothing`.
"""
function knl_fill_linear!(dest, val, us)
    # The index comes from the shared helper only; the hand-written
    # `threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x` that used to precede
    # this line was overwritten immediately and has been dropped.
    i = linear_thread_idx()
    if linear_is_valid_index(i, us)
        # Bounds checks are elided for this access (`@inbounds`).
        @inbounds dest[i] = val
    end
    return nothing
end

function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
(Ni, Nj, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
if !(VERSION ≥ v"1.11.0-beta") &&
Expand All @@ -36,13 +37,18 @@ function Base.fill!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
blocks_s = p.blocks,
)
else
args = (dest, bc, us, mask)
cart_inds = if mask isa NoMask
cartesian_indices(us)
else
cartesian_indicies_mask(us, mask)
end
args = (dest, bc, us, mask, cart_inds)
threads = threads_via_occupancy(knl_fill!, args)
n_max_threads = min(threads, get_N(us))
p = if mask isa NoMask
partition(dest, n_max_threads)
linear_partition(prod(size(dest)), n_max_threads)
else
masked_partition(us, n_max_threads, mask)
masked_partition(mask, n_max_threads, us)
end
auto_launch!(
knl_fill!,
Expand Down
41 changes: 25 additions & 16 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
# NOTE(review): this span looks like rendered-diff residue: the argument `I` and the
# `if is_valid_index(dest, I, us)` line appear to belong to the superseded revision,
# while `cart_inds`, `tidx` and the `linear_is_valid_index` guard belong to the current
# one. As written the two `if`s share a single `end`, so the text does not parse —
# confirm against the repository before editing.
#
# Current revision: write one element of `pair.first` from `pair.second` at the index
# looked up for `tidx`; returns `nothing`.
Base.@propagate_inbounds function rcopyto_at!(
pair::Pair{<:AbstractData, <:Any},
I,
cart_inds,
tidx,
us,
)
# `pair` maps a destination to the object it is filled from.
dest, bc = pair.first, pair.second
if is_valid_index(dest, I, us)
# Only act when `tidx` passes the `us` check and lies within the index table.
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
I = unval(cart_inds)[tidx]
# When `isascalar(bc)` holds, `bc` is read without an index; otherwise at `I`.
dest[I] = isascalar(bc) ? bc[] : bc[I]
end
return nothing
end
# NOTE(review): this span looks like rendered-diff residue: the one-line header taking
# `(pair, I, us)` and the `if is_valid_index(dest, I, us)` line appear to belong to the
# superseded revision; the multi-line header taking `cart_inds` and `tidx` and the
# `linear_is_valid_index` guard belong to the current one. The text does not parse as
# written — confirm against the repository before editing.
#
# Current revision: method for a `DataF` destination; returns `nothing`.
Base.@propagate_inbounds function rcopyto_at!(pair::Pair{<:DataF, <:Any}, I, us)
Base.@propagate_inbounds function rcopyto_at!(
pair::Pair{<:DataF, <:Any},
cart_inds,
tidx,
us,
)
# `pair` maps a destination to the object it is filled from.
dest, bc = pair.first, pair.second
if is_valid_index(dest, I, us)
# Only act when `tidx` passes the `us` check and lies within the index table.
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
I = unval(cart_inds)[tidx]
# When `isascalar(bc)` holds, `bc` is read without an index; otherwise at `I`.
bcI = isascalar(bc) ? bc[] : bc[I]
# The destination is written without an index (`dest[]`), unlike the source read.
dest[] = bcI
end
return nothing
end
# NOTE(review): this span looks like rendered-diff residue: every method appears twice,
# first with the `(pairs, I, us)` signature (apparently the superseded revision) and
# then with `(pairs, cart_inds, tidx, us)` (the current one). The first `function`
# header has no `end` of its own, so the text does not parse as written — confirm
# against the repository before editing.
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, us)
rcopyto_at!(first(pairs), I, us)
rcopyto_at!(Base.tail(pairs), I, us)
# General tuple: process the first element, then recurse on the rest of the tuple.
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, cart_inds, tidx, us)
rcopyto_at!(first(pairs), cart_inds, tidx, us)
rcopyto_at!(Base.tail(pairs), cart_inds, tidx, us)
end
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, us) =
rcopyto_at!(first(pairs), I, us)
@inline rcopyto_at!(pairs::Tuple{}, I, us) = nothing
# One-element tuple: forward to the method for that single element.
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, cart_inds, tidx, us) =
rcopyto_at!(first(pairs), cart_inds, tidx, us)
# Empty tuple: nothing to do.
@inline rcopyto_at!(pairs::Tuple{}, cart_inds, tidx, us) = nothing

function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us)
function knl_fused_copyto!(fmbc::FusedMultiBroadcast, dest1, us, cart_inds)
@inbounds begin
I = universal_index(dest1)
if is_valid_index(dest1, I, us)
tidx = linear_thread_idx()
if linear_is_valid_index(tidx, us) && tidx ≤ length(unval(cart_inds))
(; pairs) = fmbc
rcopyto_at!(pairs, I, us)
rcopyto_at!(pairs, cart_inds, tidx, us)
end
end
return nothing
Expand Down Expand Up @@ -138,10 +146,11 @@ function launch_fused_copyto!(fmb::FusedMultiBroadcast)
blocks_s = p.blocks,
)
else
args = (fmb, dest1, us)
cart_inds = cartesian_indices(us)
args = (fmb, dest1, us, cart_inds)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
p = linear_partition(prod(size(dest1)), n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
Expand Down
Loading
Loading