Skip to content

Commit b2fab28

Browse files
committed
Simplify array constructors.
1 parent 6c1bddd commit b2fab28

File tree

5 files changed

+17
-38
lines changed

5 files changed

+17
-38
lines changed

src/construction.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
function Base.fill(X::Type{<: GPUArray}, val::T, dims::NTuple{N, Integer}) where {T, N}
2-
res = similar(X, T, dims)
2+
res = similar(X{T}, dims)
33
fill!(res, val)
44
end
55
function Base.fill(X::Type{<: GPUArray{T}}, val, dims::NTuple{N, Integer}) where {T, N}
6-
res = similar(X, T, dims)
6+
res = similar(X, dims)
77
fill!(res, convert(T, val))
88
end
99
function Base.fill!(A::GPUArray{T}, x) where T
@@ -33,17 +33,6 @@ function (T::Type{<: GPUArray})(s::UniformScaling, dims::Dims{2})
3333
end
3434
(T::Type{<: GPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
3535

36-
(T::Type{<: GPUArray})(x) = convert(T, x)
37-
(T::Type{<: GPUArray})(dims::Integer...) = T(dims)
38-
(T::Type{<: GPUArray})(dims::NTuple{N, Base.OneTo{Int}}) where N = T(undef, length.(dims))
39-
(T::Type{<: GPUArray{X} where X})(dims::NTuple{N, Integer}) where N = similar(T, eltype(T), dims)
40-
(T::Type{<: GPUArray{X} where X})(::UndefInitializer, dims::NTuple{N, Integer}) where N = similar(T, eltype(T), dims)
41-
42-
Base.similar(x::X, ::Type{T}, size::Base.Dims{N}) where {X <: GPUArray, T, N} = similar(X, T, size)
43-
Base.similar(::Type{X}, ::Type{T}, size::NTuple{N, Base.OneTo{Int}}) where {X <: GPUArray, T, N} = similar(X, T, length.(size))
44-
45-
Base.convert(AT::Type{<: GPUArray{T, N}}, A::GPUArray{T, N}) where {T, N} = A
46-
4736
function indexstyle(x::T) where T
4837
style = try
4938
Base.IndexStyle(x)
@@ -84,15 +73,15 @@ function Base.convert(AT::Type{<: GPUArray}, iter)
8473
end
8574

8675
function Base.convert(AT::Type{<: GPUArray{T, N}}, A::DenseArray{T, N}) where {T, N}
87-
copyto!(AT(Base.size(A)), A)
76+
copyto!(AT(undef, size(A)), A)
8877
end
8978

9079
function Base.convert(AT::Type{<: GPUArray{T1}}, A::DenseArray{T2, N}) where {T1, T2, N}
91-
copyto!(similar(AT, T1, size(A)), convert(Array{T1, N}, A))
80+
copyto!(similar(AT{T1}, size(A)), convert(Array{T1, N}, A))
9281
end
9382

9483
function Base.convert(AT::Type{<: GPUArray}, A::DenseArray{T2, N}) where {T2, N}
95-
copyto!(similar(AT, T2, size(A)), A)
84+
copyto!(similar(AT{T2}, size(A)), A)
9685
end
9786

9887
function Base.convert(AT::Type{Array{T, N}}, A::GPUArray{CT, CN}) where {T, N, CT, CN}

src/random.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,16 @@ end
5151

5252
struct RNG <: AbstractRNG
5353
state::GPUArray{NTuple{4,UInt32},1}
54-
55-
function RNG(A::GPUArray)
56-
dev = GPUArrays.device(A)
57-
N = GPUArrays.threads(dev)
58-
state = similar(A, NTuple{4, UInt32}, N)
59-
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
60-
new(state)
61-
end
6254
end
6355

6456
const GLOBAL_RNGS = Dict()
6557
function global_rng(A::GPUArray)
6658
dev = GPUArrays.device(A)
6759
get!(GLOBAL_RNGS, dev) do
68-
RNG(A)
60+
N = GPUArrays.threads(dev)
61+
state = similar(A, NTuple{4, UInt32}, N)
62+
copyto!(state, [ntuple(i-> rand(UInt32), 4) for i=1:N])
63+
RNG(state)
6964
end
7065
end
7166

src/testsuite/construction.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@ end
1010
function constructors(AT)
1111
@testset "constructors + similar" begin
1212
for T in supported_eltypes()
13-
B = AT{T}(10)
13+
B = AT{T}(undef, 10)
1414
@test B isa AT{T,1}
1515
@test size(B) == (10,)
1616
@test eltype(B) == T
1717

18-
B = AT{T}(10, 10)
18+
B = AT{T}(undef, 10, 10)
1919
@test B isa AT{T,2}
2020
@test size(B) == (10, 10)
2121
@test eltype(B) == T
2222

23-
B = AT{T}((10, 10))
23+
B = AT{T}(undef, (10, 10))
2424
@test B isa AT{T,2}
2525
@test size(B) == (10, 10)
2626
@test eltype(B) == T
2727

28-
B = similar(B, Int32, (11, 15))
28+
B = similar(B, Int32, 11, 15)
2929
@test B isa AT{Int32,2}
3030
@test size(B) == (11, 15)
3131
@test eltype(B) == Int32
@@ -50,11 +50,6 @@ function constructors(AT)
5050
@test size(B) == (11, 15)
5151
@test eltype(B) == Int32
5252

53-
B = similar(AT{Int32, 2}, T, (11, 15))
54-
@test B isa AT{T,2}
55-
@test size(B) == (11, 15)
56-
@test eltype(B) == T
57-
5853
B = similar(AT{T}, (5,))
5954
@test B isa AT{T,1}
6055
@test size(B) == (5,)
@@ -127,7 +122,7 @@ function value_constructor(AT)
127122
@test Array(x3) x
128123

129124
fill!(x1, 2f0)
130-
x2 = fill!(AT{Int32}((4, 4, 4)), 77f0)
125+
x2 = fill!(AT{Int32}(undef, (4, 4, 4)), 77f0)
131126
@test all(x-> x == 2f0, Array(x1))
132127
@test all(x-> x == Int32(77), Array(x2))
133128

@@ -152,7 +147,7 @@ function iterator_constructors(AT)
152147
x = AT{Float32}(Fill(T(0), (10, 10)))
153148
@test eltype(x) == Float32
154149
@test AT(Eye{T}((10))) == AT{T}(I, 10, 10)
155-
x = AT{Float32}(Eye{T}((10)))
150+
x = AT{Float32}(Eye{T}(10))
156151
@test eltype(x) == Float32
157152
end
158153
end

src/testsuite/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function test_indexing(AT)
1313
end
1414
@testset "multi dim, sliced setindex" begin
1515
x = fill(AT{T}, T(0), (10, 10, 10, 10))
16-
y = AT{T}(5, 5, 10, 10)
16+
y = AT{T}(undef, 5, 5, 10, 10)
1717
rand!(y)
1818
x[2:6, 2:6, :, :] = y
1919
x[2:6, 2:6, :, :] == y

src/testsuite/random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function test_random(AT)
22
@testset "Random" begin
33
@testset "rand" begin # uniform
44
for T in (Float32, Float64, Int64, Int32), d in (2, (2,2))
5-
A = AT{T}(d)
5+
A = AT{T}(undef, d)
66
B = copy(A)
77
rand!(A)
88
rand!(B)

0 commit comments

Comments
 (0)