Skip to content

Commit 6046c73

Browse files
authored
Merge pull request #169 from JuliaGPU/tb/simplify
Simplify array constructors.
2 parents 6c1bddd + 904d103 commit 6046c73

File tree

7 files changed

+104
-63
lines changed

7 files changed

+104
-63
lines changed

src/array.jl

Lines changed: 59 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,37 +1,79 @@
11
# Very simple Julia back-end which is just for testing the implementation and can be used as
22
# a reference implementation
33

4-
5-
## construction
6-
74
struct JLArray{T, N} <: GPUArray{T, N}
85
data::Array{T, N}
9-
size::Dims{N}
6+
dims::Dims{N}
107

11-
function JLArray{T,N}(data::Array{T, N}, size::NTuple{N, Int}) where {T,N}
12-
new(data, size)
8+
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
9+
new(data, dims)
1310
end
1411
end
1512

16-
JLArray(data::AbstractArray{T, N}, size::Dims{N}) where {T,N} = JLArray{T,N}(data, size)
1713

18-
(::Type{<: JLArray{T}})(x::AbstractArray) where T = JLArray(convert(Array{T}, x), size(x))
14+
## construction
15+
16+
# type and dimensionality specified, accepting dims as tuples of Ints
17+
JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
18+
JLArray{T,N}(Array{T, N}(undef, dims), dims)
1919

20-
function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
21-
JLArray{T, N}(Array{T, N}(undef, size), size)
22-
end
20+
# type and dimensionality specified, accepting dims as a series of Ints
21+
JLArray{T,N}(::UndefInitializer, dims::Integer...) where {T,N} = JLArray{T,N}(undef, dims)
2322

24-
struct JLBackend <: GPUBackend end
25-
backend(::Type{<:JLArray}) = JLBackend()
23+
# type but not dimensionality specified
24+
JLArray{T}(::UndefInitializer, dims::Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
25+
JLArray{T}(::UndefInitializer, dims::Integer...) where {T} =
26+
JLArray{T}(undef, convert(Tuple{Vararg{Int}}, dims))
27+
28+
# empty vector constructor
29+
JLArray{T,1}() where {T} = JLArray{T,1}(undef, 0)
30+
31+
32+
Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
33+
Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
34+
Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
35+
36+
37+
## array interface
38+
39+
Base.elsize(::Type{<:JLArray{T}}) where {T} = sizeof(T)
40+
41+
Base.size(x::JLArray) = x.dims
42+
Base.sizeof(x::JLArray) = Base.elsize(x) * length(x)
2643

27-
## getters
2844

29-
Base.size(x::JLArray) = x.size
45+
## interop with other arrays
3046

31-
Base.pointer(x::JLArray) = pointer(x.data)
47+
JLArray{T,N}(x::AbstractArray{S,N}) where {T,N,S} =
48+
JLArray{T,N}(convert(Array{T}, x), size(x))
3249

50+
# underspecified constructors
51+
JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
52+
(::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
53+
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
3354

34-
## other
55+
# idempotency
56+
JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
57+
58+
59+
## conversions
60+
61+
Base.convert(::Type{T}, x::T) where T <: JLArray = x
62+
63+
64+
## broadcast
65+
66+
BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()
67+
68+
function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
69+
similar(JLArray{T}, axes(bc))
70+
end
71+
72+
73+
## gpuarray interface
74+
75+
struct JLBackend <: GPUBackend end
76+
backend(::Type{<:JLArray}) = JLBackend()
3577

3678
"""
3779
Thread group local memory
@@ -51,8 +93,6 @@ to_blocks(state, x) = x
5193
# unpacks local memory for each block
5294
to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
5395

54-
Base.similar(::Type{<: JLArray}, ::Type{T}, size::Base.Dims{N}) where {T, N} = JLArray{T, N}(size)
55-
5696
unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
5797
reshape(reinterpret(T, A.data), size)
5898

src/broadcast.jl

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,6 @@ const GPUDestArray = Union{GPUArray,
3131
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
3232
SubArray{<:Any,<:Any,<:GPUArray}}
3333

34-
# This method is responsible for selecting the output type of broadcast
35-
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where
36-
{GPU <: GPUArray, ElType}
37-
similar(GPU, ElType, axes(bc))
38-
end
39-
4034
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
4135
# can handle:
4236
# - `bc::Broadcast.Broadcasted{Style}`

src/construction.jl

Lines changed: 5 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
function Base.fill(X::Type{<: GPUArray}, val::T, dims::NTuple{N, Integer}) where {T, N}
2-
res = similar(X, T, dims)
2+
res = similar(X{T}, dims)
33
fill!(res, val)
44
end
55
function Base.fill(X::Type{<: GPUArray{T}}, val, dims::NTuple{N, Integer}) where {T, N}
6-
res = similar(X, T, dims)
6+
res = similar(X, dims)
77
fill!(res, convert(T, val))
88
end
99
function Base.fill!(A::GPUArray{T}, x) where T
@@ -33,17 +33,6 @@ function (T::Type{<: GPUArray})(s::UniformScaling, dims::Dims{2})
3333
end
3434
(T::Type{<: GPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
3535

36-
(T::Type{<: GPUArray})(x) = convert(T, x)
37-
(T::Type{<: GPUArray})(dims::Integer...) = T(dims)
38-
(T::Type{<: GPUArray})(dims::NTuple{N, Base.OneTo{Int}}) where N = T(undef, length.(dims))
39-
(T::Type{<: GPUArray{X} where X})(dims::NTuple{N, Integer}) where N = similar(T, eltype(T), dims)
40-
(T::Type{<: GPUArray{X} where X})(::UndefInitializer, dims::NTuple{N, Integer}) where N = similar(T, eltype(T), dims)
41-
42-
Base.similar(x::X, ::Type{T}, size::Base.Dims{N}) where {X <: GPUArray, T, N} = similar(X, T, size)
43-
Base.similar(::Type{X}, ::Type{T}, size::NTuple{N, Base.OneTo{Int}}) where {X <: GPUArray, T, N} = similar(X, T, length.(size))
44-
45-
Base.convert(AT::Type{<: GPUArray{T, N}}, A::GPUArray{T, N}) where {T, N} = A
46-
4736
function indexstyle(x::T) where T
4837
style = try
4938
Base.IndexStyle(x)
@@ -84,15 +73,15 @@ function Base.convert(AT::Type{<: GPUArray}, iter)
8473
end
8574

8675
function Base.convert(AT::Type{<: GPUArray{T, N}}, A::DenseArray{T, N}) where {T, N}
87-
copyto!(AT(Base.size(A)), A)
76+
copyto!(AT(undef, size(A)), A)
8877
end
8978

9079
function Base.convert(AT::Type{<: GPUArray{T1}}, A::DenseArray{T2, N}) where {T1, T2, N}
91-
copyto!(similar(AT, T1, size(A)), convert(Array{T1, N}, A))
80+
copyto!(similar(AT, size(A)), convert(Array{T1, N}, A))
9281
end
9382

9483
function Base.convert(AT::Type{<: GPUArray}, A::DenseArray{T2, N}) where {T2, N}
95-
copyto!(similar(AT, T2, size(A)), A)
84+
copyto!(similar(AT{T2}, size(A)), A)
9685
end
9786

9887
function Base.convert(AT::Type{Array{T, N}}, A::GPUArray{CT, CN}) where {T, N, CT, CN}

src/random.jl

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -51,21 +51,16 @@ end
5151

5252
struct RNG <: AbstractRNG
5353
state::GPUArray{NTuple{4,UInt32},1}
54-
55-
function RNG(A::GPUArray)
56-
dev = GPUArrays.device(A)
57-
N = GPUArrays.threads(dev)
58-
state = similar(A, NTuple{4, UInt32}, N)
59-
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
60-
new(state)
61-
end
6254
end
6355

6456
const GLOBAL_RNGS = Dict()
6557
function global_rng(A::GPUArray)
6658
dev = GPUArrays.device(A)
6759
get!(GLOBAL_RNGS, dev) do
68-
RNG(A)
60+
N = GPUArrays.threads(dev)
61+
state = similar(A, NTuple{4, UInt32}, N)
62+
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
63+
RNG(state)
6964
end
7065
end
7166

src/testsuite/construction.jl

Lines changed: 34 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,22 +10,22 @@ end
1010
function constructors(AT)
1111
@testset "constructors + similar" begin
1212
for T in supported_eltypes()
13-
B = AT{T}(10)
13+
B = AT{T}(undef, 10)
1414
@test B isa AT{T,1}
1515
@test size(B) == (10,)
1616
@test eltype(B) == T
1717

18-
B = AT{T}(10, 10)
18+
B = AT{T}(undef, 10, 10)
1919
@test B isa AT{T,2}
2020
@test size(B) == (10, 10)
2121
@test eltype(B) == T
2222

23-
B = AT{T}((10, 10))
23+
B = AT{T}(undef, (10, 10))
2424
@test B isa AT{T,2}
2525
@test size(B) == (10, 10)
2626
@test eltype(B) == T
2727

28-
B = similar(B, Int32, (11, 15))
28+
B = similar(B, Int32, 11, 15)
2929
@test B isa AT{Int32,2}
3030
@test size(B) == (11, 15)
3131
@test eltype(B) == Int32
@@ -50,11 +50,6 @@ function constructors(AT)
5050
@test size(B) == (11, 15)
5151
@test eltype(B) == Int32
5252

53-
B = similar(AT{Int32, 2}, T, (11, 15))
54-
@test B isa AT{T,2}
55-
@test size(B) == (11, 15)
56-
@test eltype(B) == T
57-
5853
B = similar(AT{T}, (5,))
5954
@test B isa AT{T,1}
6055
@test size(B) == (5,)
@@ -66,6 +61,34 @@ function constructors(AT)
6661
@test eltype(B) == T
6762
end
6863
end
64+
@testset "comparison against Array" begin
65+
for typs in [(), (Int,), (Int,1), (Int,2), (Float32,), (Float32,1), (Float32,2)],
66+
args in [(), (1,), (1,2), ((1,),), ((1,2),),
67+
(undef,), (undef, 1,), (undef, 1,2), (undef, (1,),), (undef, (1,2),),
68+
(Int,), (Int, 1,), (Int, 1,2), (Int, (1,),), (Int, (1,2),),
69+
([1,2],), ([1 2],)]
70+
cpu = try
71+
Array{typs...}(args...)
72+
catch ex
73+
isa(ex, MethodError) || rethrow()
74+
nothing
75+
end
76+
77+
gpu = try
78+
AT{typs...}(args...)
79+
catch ex
80+
isa(ex, MethodError) || rethrow()
81+
cpu == nothing || rethrow()
82+
nothing
83+
end
84+
85+
if cpu == nothing
86+
@test gpu == nothing
87+
else
88+
@test typeof(cpu) == typeof(convert(Array, gpu))
89+
end
90+
end
91+
end
6992
end
7093

7194
function conversion(AT)
@@ -127,7 +150,7 @@ function value_constructor(AT)
127150
@test Array(x3) ≈ x
128151

129152
fill!(x1, 2f0)
130-
x2 = fill!(AT{Int32}((4, 4, 4)), 77f0)
153+
x2 = fill!(AT{Int32}(undef, (4, 4, 4)), 77f0)
131154
@test all(x-> x == 2f0, Array(x1))
132155
@test all(x-> x == Int32(77), Array(x2))
133156

@@ -152,7 +175,7 @@ function iterator_constructors(AT)
152175
x = AT{Float32}(Fill(T(0), (10, 10)))
153176
@test eltype(x) == Float32
154177
@test AT(Eye{T}((10))) == AT{T}(I, 10, 10)
155-
x = AT{Float32}(Eye{T}((10)))
178+
x = AT{Float32}(Eye{T}(10))
156179
@test eltype(x) == Float32
157180
end
158181
end

src/testsuite/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ function test_indexing(AT)
1313
end
1414
@testset "multi dim, sliced setindex" begin
1515
x = fill(AT{T}, T(0), (10, 10, 10, 10))
16-
y = AT{T}(5, 5, 10, 10)
16+
y = AT{T}(undef, 5, 5, 10, 10)
1717
rand!(y)
1818
x[2:6, 2:6, :, :] = y
1919
x[2:6, 2:6, :, :] == y

src/testsuite/random.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ function test_random(AT)
22
@testset "Random" begin
33
@testset "rand" begin # uniform
44
for T in (Float32, Float64, Int64, Int32), d in (2, (2,2))
5-
A = AT{T}(d)
5+
A = AT{T}(undef, d)
66
B = copy(A)
77
rand!(A)
88
rand!(B)

0 commit comments

Comments
 (0)