Skip to content

Commit 5a83d5b

Browse files
committed
Remove redundant array argument.
1 parent 4dbbfc6 commit 5a83d5b

File tree

11 files changed

+47
-38
lines changed

11 files changed

+47
-38
lines changed

src/device/execution.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,42 @@ abstract type AbstractGPUBackend end
77
abstract type AbstractKernelContext end
88

99
"""
10-
backend(T::Type{<:AbstractArray})
10+
backend(T::Type)
11+
backend(x)
1112
1213
Gets the GPUArrays back-end responsible for managing arrays of type `T`.
1314
"""
14-
backend(::Type{<:AbstractArray}) = error("This array is not a GPU array") # COV_EXCL_LINE
15+
backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
16+
backend(x) = backend(typeof(x))
1517

1618
"""
17-
gpu_call(kernel::Function, A::AbstractGPUArray, args...; kwargs...)
19+
gpu_call(kernel::Function, arg0, args...; kwargs...)
1820
19-
Calls function `kernel` on the GPU device that backs array `A`, passing along arguments
20-
`args`. The keyword arguments `kwargs` are not passed along, but are interpreted on the host
21-
to influence how the kernel is executed. The following keyword arguments are supported:
21+
Executes `kernel` on the device that backs `arg` (see [`backend`](@ref)), passing along any
22+
arguments `args`. Additionally, the kernel will be passed the kernel execution context (see
23+
[`AbstractKernelContext`]), so its signature should be `(ctx::AbstractKernelContext, arg0,
24+
args...)`.
2225
26+
The keyword arguments `kwargs` are not passed to the function, but are interpreted on the
27+
host to influence how the kernel is executed. The following keyword arguments are supported:
28+
29+
- `target::AbstractArray`: specify which array object to use for determining execution
30+
properties (defaults to the first argument `arg0`).
2331
- `total_threads::Int`: how many threads should be launched _in total_. The actual number of
24-
threads and blocks is determined using a heuristic. Defaults to the length of `A` if no
25-
other keyword arguments that influence the launch configuration are specified.
32+
threads and blocks is determined using a heuristic. Defaults to the length of `arg0` if
33+
no other keyword arguments that influence the launch configuration are specified.
2634
- `threads::Int` and `blocks::Int`: configure exactly how many threads and blocks are
27-
launched. This cannot be used in combination with the `total_threads` argument.
35+
launched. This cannot be used in combination with the `total_threads` argument.
2836
"""
29-
function gpu_call(kernel::Base.Callable, A::AbstractArray, args...;
37+
function gpu_call(kernel::Base.Callable, args...;
38+
target::AbstractArray=first(args),
3039
total_threads::Union{Int,Nothing}=nothing,
3140
threads::Union{Int,Nothing}=nothing,
3241
blocks::Union{Int,Nothing}=nothing,
3342
kwargs...)
3443
# determine how many threads/blocks to launch
3544
if total_threads===nothing && threads===nothing && blocks===nothing
36-
total_threads = length(A)
45+
total_threads = length(target)
3746
end
3847
if total_threads !== nothing
3948
if threads !== nothing || blocks !== nothing
@@ -49,7 +58,7 @@ function gpu_call(kernel::Base.Callable, A::AbstractArray, args...;
4958
end
5059
end
5160

52-
gpu_call(backend(typeof(A)), kernel, args...; threads=threads, blocks=blocks, kwargs...)
61+
gpu_call(backend(target), kernel, args...; threads=threads, blocks=blocks, kwargs...)
5362
end
5463

5564
gpu_call(backend::AbstractGPUBackend, kernel, args...; kwargs...) = error("Not implemented") # COV_EXCL_LINE

src/host/abstractarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function Base.copyto!(dest::AbstractGPUArray{T, N}, destcrange::CartesianIndices
161161

162162
dest_offsets = first.(destcrange.indices) .- 1
163163
src_offsets = first.(srccrange.indices) .- 1
164-
gpu_call(copy_kernel!, dest,
164+
gpu_call(copy_kernel!,
165165
dest, dest_offsets, src, src_offsets, shape, size(dest), size(src), len;
166166
total_threads=len)
167167
dest

src/host/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
function Base.repeat(a::AbstractGPUVecOrMat, m::Int, n::Int = 1)
6363
o, p = size(a, 1), size(a, 2)
6464
b = similar(a, o*m, p*n)
65-
gpu_call(a, b, a, o, p, m, n; total_threads=n) do ctx, b, a, o, p, m, n
65+
gpu_call(b, a, o, p, m, n; target=a, total_threads=n) do ctx, b, a, o, p, m, n
6666
j = linear_index(ctx)
6767
j > n && return
6868
d = (j - 1) * p + 1
@@ -82,7 +82,7 @@ end
8282
function Base.repeat(a::AbstractGPUVector, m::Int)
8383
o = length(a)
8484
b = similar(a, o*m)
85-
gpu_call(a, b, a, o, m; total_threads=m) do ctx, b, a, o, m
85+
gpu_call(b, a, o, m; target=a, total_threads=m) do ctx, b, a, o, m
8686
i = linear_index(ctx)
8787
i > m && return
8888
c = (i - 1)*o + 1

src/host/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747
@inline function Base.copyto!(dest::GPUDestArray, bc::Broadcasted{Nothing})
4848
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
4949
bc′ = Broadcast.preprocess(dest, bc)
50-
gpu_call(dest, dest, bc′) do ctx, dest, bc′
50+
gpu_call(dest, bc′) do ctx, dest, bc′
5151
let I = CartesianIndex(@cartesianidx(dest))
5252
@inbounds dest[I] = bc′[I]
5353
end

src/host/construction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function Base.fill(X::Type{<: AbstractGPUArray{T}}, val, dims::NTuple{N, Integer
99
fill!(res, convert(T, val))
1010
end
1111
function Base.fill!(A::AbstractGPUArray{T}, x) where T
12-
gpu_call(A, A, convert(T, x)) do ctx, a, val
12+
gpu_call(A, convert(T, x)) do ctx, a, val
1313
idx = @linearidx(a, ctx)
1414
@inbounds a[idx] = val
1515
return
@@ -30,7 +30,7 @@ end
3030

3131
function (T::Type{<: AbstractGPUArray})(s::UniformScaling, dims::Dims{2})
3232
res = zeros(T, dims)
33-
gpu_call(uniformscaling_kernel, res, res, size(res, 1), s; total_threads=minimum(dims))
33+
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
3434
res
3535
end
3636
(T::Type{<: AbstractGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
@@ -67,7 +67,7 @@ function Base.convert(AT::Type{<: AbstractGPUArray}, iter)
6767
if isbits(iter) && isa(isize, Base.HasShape) && style != nothing && isa(ettrait, Base.HasEltype)
6868
# We can collect on the GPU
6969
A = similar(AT, eltype_or(AT, eltype(iter)), size(iter))
70-
gpu_call(collect_kernel, A, A, iter, style)
70+
gpu_call(collect_kernel, A, iter, style)
7171
A
7272
else
7373
convert(AT, collect(iter))

src/host/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function Base._unsafe_getindex!(dest::AbstractGPUArray, src::AbstractGPUArray, I
9898
return dest
9999
end
100100
idims = map(length, Is)
101-
gpu_call(index_kernel, dest, dest, src, idims, map(x-> to_index(dest, x), Is))
101+
gpu_call(index_kernel, dest, src, idims, map(x-> to_index(dest, x), Is))
102102
return dest
103103
end
104104

@@ -125,7 +125,7 @@ function Base._unsafe_setindex!(::IndexStyle, dest::T, src, Is::Union{Real, Abst
125125
idims = length.(Is)
126126
len = prod(idims)
127127
src_gpu = adapt(T, src)
128-
gpu_call(setindex_kernel!, dest, dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len;
128+
gpu_call(setindex_kernel!, dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len;
129129
total_threads=len)
130130
return dest
131131
end

src/host/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ end
115115
## high-level functionality
116116

117117
function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
118-
gpu_call(At, At, A) do ctx, At, A
118+
gpu_call(At, A) do ctx, At, A
119119
idx = @cartesianidx A ctx
120120
@inbounds At[idx[2], idx[1]] = A[idx[1], idx[2]]
121121
return
@@ -129,7 +129,7 @@ end
129129

130130
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
131131
perm isa Tuple || (perm = Tuple(perm))
132-
gpu_call(dest, dest, src, perm) do ctx, dest, src, perm
132+
gpu_call(dest, src, perm) do ctx, dest, src, perm
133133
I = @cartesianidx src ctx
134134
@inbounds dest[genperm(I, perm)...] = src[I...]
135135
return

src/host/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118

119119
function Base._mapreducedim!(f, op, R::AbstractGPUArray, A::GPUSrcArray)
120120
range = ifelse.(length.(axes(R)) .== 1, axes(A), nothing)
121-
gpu_call(mapreducedim_kernel, R, f, op, R, A, range)
121+
gpu_call(mapreducedim_kernel, f, op, R, A, range; target=R)
122122
return R
123123
end
124124

@@ -174,8 +174,8 @@ function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest...) where {OT}
174174
end
175175
out = similar(A, OT, (blocks,))
176176
fill!(out, v0)
177-
gpu_call(reduce_kernel, out, f, op, v0, A, Val{threads}(), out, rest...;
178-
threads=threads, blocks=blocks)
177+
gpu_call(reduce_kernel, f, op, v0, A, Val{threads}(), out, rest...;
178+
target=out, threads=threads, blocks=blocks)
179179
reduce(op, Array(out))
180180
end
181181

src/host/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ function global_rng(A::AbstractGPUArray)
7777
end
7878

7979
function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
80-
gpu_call(A, rng.state, A) do ctx, randstates, a
80+
gpu_call(rng.state, A; target=A) do ctx, randstates, a
8181
idx = linear_index(ctx)
8282
idx > length(a) && return
8383
@inbounds a[idx] = gpu_rand(T, ctx, randstates)

test/testsuite/base.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,25 @@ function test_base(AT)
108108

109109
@testset "ntuple test" begin
110110
result = AT(Vector{NTuple{3, Float32}}(undef, 1))
111-
gpu_call(ntuple_test, result, result, Val(3))
111+
gpu_call(ntuple_test, result, Val(3))
112112
@test Array(result)[1] == (77, 2*77, 3*77)
113113
x = 88f0
114-
gpu_call(ntuple_closure, result, result, Val(3), x)
114+
gpu_call(ntuple_closure, result, Val(3), x)
115115
@test Array(result)[1] == (x, 2*x, 3*x)
116116
end
117117

118118
@testset "cartesian iteration" begin
119119
Ac = rand(Float32, 32, 32)
120120
A = AT(Ac)
121121
result = fill!(copy(A), 0.0)
122-
gpu_call(cartesian_iter, result, A, result, size(A))
122+
gpu_call(cartesian_iter, A, result, size(A); target=result)
123123
Array(result) == Ac
124124
end
125125

126126
@testset "Custom kernel from Julia function" begin
127127
x = AT(rand(Float32, 100))
128128
y = AT(rand(Float32, 100))
129-
gpu_call(clmap!, x, -, x, y)
129+
gpu_call(clmap!, -, x, y; target=x)
130130
jy = Array(y)
131131
@test map!(-, jy, jy) Array(x)
132132
end

0 commit comments

Comments
 (0)