Skip to content

Commit 2251ae9

Browse files
MikeInnes authored and SimonDanisch committed
use CuArrays indexing
1 parent 6d18ada commit 2251ae9

File tree

2 files changed

+55
-48
lines changed

2 files changed

+55
-48
lines changed

src/construction.jl

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,6 @@ end
4343
function convert(AT::Type{<: GPUArray{T1}}, A::DenseArray{T2}) where {T1, T2}
4444
copy!(similar(AT, T1, size(A)), T1.(A))
4545
end
46-
using Colors
47-
function convert(AT::Type{<: GPUArray{T1}}, A::DenseArray{T2}) where {T1 <: Colorant, T2 <: Colorant}
48-
copy!(similar(AT, T1, size(A)), T1.(A))
49-
end
5046
function convert(AT::Type{<: GPUArray}, A::DenseArray{T2, N}) where {T2, N}
5147
copy!(similar(AT, T2, size(A)), A)
5248
end

src/indexing.jl

Lines changed: 55 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,53 +1,64 @@
1-
array_convert{T, N}(t::Type{Array{T, N}}, x::Array) = convert(t, x)
2-
array_convert{T, N}(t::Type{Array{T, N}}, x::T) = [x]
1+
const _allowslow = Ref(true)
32

4-
function array_convert{T, N, T2}(t::Type{Array{T, N}}, x::T2)
5-
arr = collect(x) # iterator
6-
dims = ntuple(Val{N}) do i
7-
ifelse(ndims(arr) >= i, size(arr, i), 1)
8-
end
9-
return reshape(map(T, arr), dims) # broadcast dims
3+
allowslow(flag = true) = (_allowslow[] = flag)
4+
5+
function assertslow(op = "Operation")
6+
_allowslow[] || error("$op is disabled")
7+
return
108
end
119

12-
indexlength(A, i, array::AbstractArray) = length(array)
13-
indexlength(A, i, array::Number) = 1
14-
indexlength(A, i, array::Colon) = size(A, i)
10+
Base.IndexStyle(::Type{<:GPUArray}) = IndexLinear()
1511

16-
function Base.setindex!{T, N}(A::GPUArray{T, N}, value, indexes...)
17-
# similarly, value should always be a julia array
18-
shape = ntuple(Val{N}) do i
19-
indexlength(A, i, indexes[i])
20-
end
21-
if !isa(value, T) # TODO, shape check errors for x[1:3] = 1
22-
Base.setindex_shape_check(value, indexes...)
12+
function _getindex(xs::GPUArray{T}, i::Integer) where T
13+
x = Array{T}(1)
14+
copy!(x, 1, xs, i, 1)
15+
return x[1]
16+
end
17+
18+
function Base.getindex(xs::GPUArray{T}, i::Integer) where T
19+
assertslow("getindex")
20+
_getindex(xs, i)
21+
end
22+
23+
function _setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
24+
x = T[v]
25+
copy!(xs, i, x, 1, 1)
26+
return v
27+
end
28+
29+
function Base.setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
30+
assertslow("setindex!")
31+
_setindex!(xs, v, i)
32+
end
33+
34+
Base.setindex!(xs::GPUArray, v, i::Integer) = xs[i] = convert(eltype(xs), v)
35+
36+
# Vector indexing
37+
38+
using Base.Cartesian
39+
to_index(a, x) = x
40+
to_index(::A, x::Array{ET}) where {A, ET} = copy!(similar(A, ET, size(x)), x)
41+
to_index(a, x::UnitRange{<: Integer}) = convert(UnitRange{Cuint}, x)
42+
to_index(a, x::Base.LogicalIndex) = error("Logical indexing not implemented")
43+
44+
@generated function index_kernel(state, dest::AbstractArray, src::AbstractArray, idims, Is)
45+
N = length(Is.parameters)
46+
quote
47+
i = linear_index(state)
48+
i > length(dest) && return
49+
is = gpu_ind2sub(idims, i)
50+
@nexprs $N i -> @inbounds I_i = Is[i][Int(is[i])]
51+
@inbounds dest[i] = @ncall $N getindex src i -> I_i
52+
return
2353
end
24-
checkbounds(A, indexes...)
25-
v = array_convert(Array{T, N}, value)
26-
# since you shouldn't update GPUArrays with single indices, we simplify the interface
27-
# by always mapping to ranges
28-
ranges_dest = to_cartesian(A, indexes)
29-
ranges_src = CartesianRange(size(v))
30-
31-
copy!(A, ranges_dest, v, ranges_src)
32-
return
3354
end
3455

35-
Base.getindex{T}(A::GPUArray{T, 0}) = Array(A)[]
36-
37-
function Base.getindex{T, N}(A::GPUArray{T, N}, indexes...)
38-
cindexes = Base.to_indices(A, indexes)
39-
# similarly, value should always be a julia array
40-
# We shouldn't really bother about checkbounds performance, since setindex/getindex will always be relatively slow
41-
checkbounds(A, cindexes...)
42-
43-
shape = map(length, cindexes)
44-
result = Array{T, length(shape)}(shape)
45-
ranges_src = to_cartesian(A, cindexes)
46-
ranges_dest = CartesianRange(shape)
47-
copy!(result, ranges_dest, A, ranges_src)
48-
if all(i-> isa(i, Integer), cindexes) # scalar
49-
return result[]
50-
else
51-
return result
56+
57+
function Base._unsafe_getindex!(dest::GPUArray, src::GPUArray, Is::Union{Real, AbstractArray}...)
58+
if length(Is) == 1 && isa(first(Is), Array) && isempty(first(Is)) # indexing with empty array
59+
return dest
5260
end
61+
idims = map(length, Is)
62+
gpu_call(index_kernel, dest, (dest, src, Cuint.(idims), map(x-> to_index(dest, x), Is)))
63+
return dest
5364
end

0 commit comments

Comments
 (0)