Skip to content

Commit 4c3854e

Browse files
committed
allow construction from isbits iterators
1 parent c9bd1bb commit 4c3854e

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

src/construction.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,45 @@ similar(x::X, ::Type{T}, size::Base.Dims{N}) where {X <: GPUArray, T, N} = simil
4141

4242
convert(AT::Type{<: GPUArray{T, N}}, A::GPUArray{T, N}) where {T, N} = A
4343

44+
function indexstyle(x::T) where T
45+
style = try
46+
Base.IndexStyle(x)
47+
catch
48+
nothing
49+
end
50+
style
51+
end
52+
53+
function collect_kernel(state, A, iter, ::IndexCartesian)
54+
idx = @cartesianidx(A, state)
55+
@inbounds A[idx...] = iter[idx...]
56+
return
57+
end
58+
59+
function collect_kernel(state, A, iter, ::IndexLinear)
60+
idx = linear_index(state)
61+
@inbounds A[idx] = iter[idx]
62+
return
63+
end
64+
65+
eltype_or(::Type{<: GPUArray}, or) = or
66+
eltype_or(::Type{<: GPUArray{T}}, or) where T = T
67+
eltype_or(::Type{<: GPUArray{T, N}}, or) where {T, N} = T
68+
69+
function convert(AT::Type{<: GPUArray}, iter)
70+
isize = Base.iteratorsize(iter)
71+
style = indexstyle(iter)
72+
ettrait = Base.iteratoreltype(iter)
73+
if isbits(iter) && isize == Base.HasShape() && style != nothing && ettrait == Base.HasEltype()
74+
# We can collect on the GPU
75+
A = similar(AT, eltype_or(AT, eltype(iter)), size(iter))
76+
gpu_call(collect_kernel, A, (A, iter, style))
77+
A
78+
else
79+
convert(AT, collect(iter))
80+
end
81+
end
82+
4483
function convert(AT::Type{<: GPUArray{T, N}}, A::DenseArray{T, N}) where {T, N}
4584
copy!(AT(Base.size(A)), A)
4685
end

src/mapreduce.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ import Base: any, count, countnz
33
#############################
44
# reduce
55
# functions in base implemented with a direct loop need to be overloaded to use mapreduce
6-
any(pred, A::GPUArray) = Bool(mapreduce(pred, |, Cint(0), A))
6+
any(pred, A::GPUArray) = Bool(mapreduce(pred, |, Int32(0), A))
77
count(pred, A::GPUArray) = Int(mapreduce(pred, +, UInt32(0), A))
88
countnz(A::GPUArray) = Int(mapreduce(x-> x != 0, +, UInt32(0), A))
99
countnz(A::GPUArray, dim) = Int(mapreducedim(x-> x != 0, +, UInt32(0), A, dim))
1010

11+
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, Int32(1), A, B))
1112

1213
# hack to get around of fetching the first element of the GPUArray
1314
# as a startvalue, which is a bit complicated with the current reduce implementation

src/testsuite/construction.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,44 @@
11
using GPUArrays
22
using Base.Test, GPUArrays.TestSuite
33

4+
5+
6+
# It's kind of annoying to make FillArrays only a test dependency
7+
# so for texting the conversion to GPUArrays of shaped iterators,
8+
# I just copied the core types from FillArrays:s
9+
10+
abstract type AbstractFill{T, N} <: AbstractArray{T, N} end
11+
@inline function Base.getindex(F::AbstractFill, k::Integer)
12+
@boundscheck checkbounds(F, k)
13+
getindex_value(F)
14+
end
15+
@inline function Base.getindex(F::AbstractFill{T, N}, kj::Vararg{<:Integer, N}) where {T, N}
16+
@boundscheck checkbounds(F, kj...)
17+
getindex_value(F)
18+
end
19+
Base.IndexStyle(F::AbstractFill) = IndexLinear()
20+
struct Fill{T, N} <: AbstractFill{T, N}
21+
value::T
22+
size::NTuple{N, Int}
23+
end
24+
getindex_value(x::Fill) = x.value
25+
@inline Base.size(F::Fill) = F.size
26+
struct Eye{T} <: AbstractMatrix{T}
27+
size::NTuple{2, Int}
28+
end
29+
Base.size(E::Eye) = E.size
30+
@inline function Base.getindex(E::Eye{T}, k::Integer, j::Integer) where T
31+
@boundscheck checkbounds(E, k, j)
32+
ifelse(k == j, one(T), zero(T))
33+
end
34+
35+
436
function run_construction(Typ)
537
@testset "Construction" begin
638
constructors(Typ)
739
conversion(Typ)
840
value_constructor(Typ)
41+
iterator_constructors(Typ)
942
end
1043
end
1144

@@ -124,7 +157,7 @@ function value_constructor(Typ)
124157
@test all(x-> x == Int32(77), Array(x2))
125158

126159
x = eye(T, 2, 2)
127-
160+
128161
x1 = eye(Typ{T, 2}, 2, 2)
129162
x2 = eye(Typ{T}, (2, 2))
130163
x3 = eye(Typ{T, 2}, (2, 2))
@@ -135,3 +168,17 @@ function value_constructor(Typ)
135168
end
136169
end
137170
end
171+
function iterator_constructors(Typ)
172+
@testset "iterator constructors" begin
173+
for T in supported_eltypes()
174+
@test Typ(Fill(T(0), (10,))) == zeros(Typ{T}, 10)
175+
@test Typ(Fill(T(0), (10, 10))) == zeros(Typ{T}, 10, 10)
176+
x = Typ{Float32}(Fill(T(0), (10, 10)))
177+
eltype(x) == Float32
178+
179+
@test Typ(Eye{T}((10, 10))) == eye(Typ{T}, 10, 10)
180+
x = Typ{Float32}(Eye{T}((10, 10)))
181+
@test eltype(x) == Float32
182+
end
183+
end
184+
end

0 commit comments

Comments
 (0)