Skip to content

Commit a27dcc4

Browse files
committed
Enable allowscalar(false) testing for JLArray.
1 parent bb83e17 commit a27dcc4

File tree

5 files changed

+51
-34
lines changed

5 files changed

+51
-34
lines changed

src/array.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
# Very simple Julia backend which is just for testing the implementation
2-
# and can be used as a reference implementation
1+
# Very simple Julia back-end which is just for testing the implementation and can be used as
2+
# a reference implementation
3+
4+
5+
## construction
36

47
struct JLArray{T, N} <: GPUArray{T, N}
58
data::Array{T, N}
@@ -12,6 +15,22 @@ end
1215

1316
JLArray(data::AbstractArray{T, N}, size::Dims{N}) where {T,N} = JLArray{T,N}(data, size)
1417

18+
(::Type{<: JLArray{T}})(x::AbstractArray) where T = JLArray(convert(Array{T}, x), size(x))
19+
20+
function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
21+
JLArray{T, N}(Array{T, N}(undef, size), size)
22+
end
23+
24+
25+
## getters
26+
27+
size(x::JLArray) = x.size
28+
29+
pointer(x::JLArray) = pointer(x.data)
30+
31+
32+
## I/O
33+
1534
Base.show(io::IO, x::JLArray) = show(io, collect(x))
1635
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:JLArray}) = show(io, LinearAlgebra.adjoint(collect(x.parent)))
1736
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:JLArray}) = show(io, LinearAlgebra.transpose(collect(x.parent)))
@@ -20,15 +39,16 @@ Base.show(io::IO, ::MIME"text/plain", x::JLArray) = show(io, MIME"text/plain"(),
2039
Base.show(io::IO, ::MIME"text/plain", x::LinearAlgebra.Adjoint{<:Any,<:JLArray}) = show(io, MIME"text/plain"(), LinearAlgebra.adjoint(collect(x.parent)))
2140
Base.show(io::IO, ::MIME"text/plain", x::LinearAlgebra.Transpose{<:Any,<:JLArray}) = show(io, MIME"text/plain"(), LinearAlgebra.transpose(collect(x.parent)))
2241

42+
43+
## other
44+
2345
"""
2446
Thread group local memory
2547
"""
2648
struct LocalMem{N, T}
2749
x::NTuple{N, Vector{T}}
2850
end
2951

30-
size(x::JLArray) = x.size
31-
pointer(x::JLArray) = pointer(x.data)
3252
to_device(state, x::JLArray) = x.data
3353
to_device(state, x::Tuple) = to_device.(Ref(state), x)
3454
to_device(state, x::RefValue{<: JLArray}) = RefValue(to_device(state, x[]))
@@ -40,12 +60,6 @@ to_blocks(state, x) = x
4060
# unpacks local memory for each block
4161
to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
4262

43-
(::Type{<: JLArray{T}})(x::AbstractArray) where T = JLArray(convert(Array{T}, x), size(x))
44-
45-
function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
46-
JLArray{T, N}(Array{T, N}(undef, size), size)
47-
end
48-
4963
similar(::Type{<: JLArray}, ::Type{T}, size::Base.Dims{N}) where {T, N} = JLArray{T, N}(size)
5064

5165
function unsafe_reinterpret(::Type{T}, A::JLArray{ET}, size::NTuple{N, Integer}) where {T, ET, N}
@@ -131,7 +145,8 @@ function _gpu_call(f, A::JLArray, args::Tuple, blocks_threads::Tuple{T, T}) wher
131145
block_args = to_blocks.(Ref(state), device_args)
132146
for threadidx in CartesianIndices(threads)
133147
thread_state = JLState(state, threadidx.I)
134-
tasks[threadidx] = @async f(thread_state, block_args...)
148+
tasks[threadidx] = @async @allowscalar f(thread_state, block_args...)
149+
# TODO: @async obfuscates the trace to any exception which happens during f
135150
end
136151
for t in tasks
137152
fetch(t)
@@ -146,7 +161,6 @@ device(x::JLArray) = JLDevice()
146161
threads(dev::JLDevice) = 256
147162
blocks(dev::JLDevice) = (256, 256, 256)
148163

149-
150164
@inline function synchronize_threads(::JLState)
151165
#=
152166
All threads are getting started asynchronously, so a yield will
@@ -168,8 +182,9 @@ end
168182
blas_module(::JLArray) = LinearAlgebra.BLAS
169183
blasbuffer(A::JLArray) = A.data
170184

185+
# defining our own plan type is the easiest way to pass around the plans in Base interface
186+
# without ambiguities
171187

172-
# defining our own plan type is the easiest way to pass around the plans in Base interface without ambiguities
173188
struct FFTPlan{T}
174189
p::T
175190
end
@@ -192,7 +207,6 @@ function plan_ifft(A::JLArray; kw_args...)
192207
FFTPlan(plan_ifft(A.data; kw_args...))
193208
end
194209

195-
196210
function *(plan::FFTPlan, A::JLArray)
197211
x = plan.p * A.data
198212
JLArray(x)

src/indexing.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1-
const _allowslow = Ref(true)
1+
const _allowscalar = Ref(true)
22

3-
allowslow(flag = true) = (_allowslow[] = flag)
3+
allowscalar(flag = true) = (_allowscalar[] = flag)
44

5-
function assertslow(op = "Operation")
6-
# _allowslow[] || error("$op is disabled")
7-
return
5+
"""
    assertscalar(op = "Operation")

Raise an error naming `op` when the global `_allowscalar` flag is `false`; otherwise do
nothing. Always returns `nothing`.
"""
function assertscalar(op = "Operation")
    if !_allowscalar[]
        error("$op is disabled")
    end
    return nothing
end
9+
10+
"""
    @allowscalar ex

Evaluate `ex` with the global `_allowscalar` flag temporarily set to `true`, then put the
flag back to the value it had on entry. The value of `ex` is the value of the macro call.

The flag is restored in a `finally` clause, so it is reset even when `ex` throws.
"""
macro allowscalar(ex)
    quote
        # NOTE(review): the flag is one shared global, so overlapping uses from tasks that
        # yield inside `ex` can observe each other's save/restore -- confirm acceptable.
        local prev = _allowscalar[]
        _allowscalar[] = true
        try
            $(esc(ex))
        finally
            # restore unconditionally so an exception in `ex` cannot leave the flag enabled
            _allowscalar[] = prev
        end
    end
end
919

1020
Base.IndexStyle(::Type{<:GPUArray}) = Base.IndexLinear()
@@ -16,7 +26,7 @@ function _getindex(xs::GPUArray{T}, i::Integer) where T
1626
end
1727

1828
function Base.getindex(xs::GPUArray{T}, i::Integer) where T
19-
assertslow("getindex")
29+
ndims(xs) > 0 && assertscalar("scalar getindex")
2030
_getindex(xs, i)
2131
end
2232

@@ -27,12 +37,13 @@ function _setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
2737
end
2838

2939
function Base.setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
30-
assertslow("setindex!")
40+
assertscalar("scalar setindex!")
3141
_setindex!(xs, v, i)
3242
end
3343

3444
Base.setindex!(xs::GPUArray, v, i::Integer) = xs[i] = convert(eltype(xs), v)
3545

46+
3647
# Vector indexing
3748

3849
to_index(a, x) = x

src/testsuite.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ end
6060
Runs the GPUArrays test suite on array type `Typ`
6161
"""
6262
function test(Typ)
63-
GPUArrays.allowslow(false)
63+
# TODO: more fine-grained allowscalar within test_indexing
64+
GPUArrays.allowscalar(true)
65+
TestSuite.test_indexing(Typ)
66+
67+
GPUArrays.allowscalar(false)
6468
TestSuite.test_gpuinterface(Typ)
6569
TestSuite.test_base(Typ)
6670
TestSuite.test_blas(Typ)
@@ -69,6 +73,5 @@ function test(Typ)
6973
TestSuite.test_fft(Typ)
7074
TestSuite.test_linalg(Typ)
7175
TestSuite.test_mapreduce(Typ)
72-
TestSuite.test_indexing(Typ)
7376
TestSuite.test_random(Typ)
7477
end

src/testsuite/base.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ function test_base(Typ)
6868
@test Array(a) == x
6969
end
7070

71-
GPUArrays.allowslow(true)
72-
# right now in CLArrays we fallback to geindex since on some hardware
73-
# somehow the vcat kernel segfaults -.-
7471
@testset "vcat + hcat" begin
7572
x = fill(0f0, (10, 10))
7673
y = rand(Float32, 20, 10)
@@ -84,7 +81,6 @@ function test_base(Typ)
8481
against_base(hcat, Typ{Float32}, (3, 3), (3, 3))
8582
against_base(vcat, Typ{Float32}, (3, 3), (3, 3))
8683
end
87-
GPUArrays.allowslow(false)
8884

8985
@testset "reinterpret" begin
9086
a = rand(ComplexF32, 22)

src/testsuite/indexing.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@ function test_indexing(Typ)
44
@testset "Indexing with $T" begin
55
x = rand(T, 32)
66
src = Typ(x)
7-
GPUArrays.allowslow(true)
87
for (i, xi) in enumerate(x)
98
@test src[i] == xi
109
end
11-
GPUArrays.allowslow(false)
1210
@test Array(src[1:3]) == x[1:3]
1311
@test Array(src[3:end]) == x[3:end]
1412
end
@@ -25,7 +23,6 @@ function test_indexing(Typ)
2523
@testset "Indexing with $T" begin
2624
x = fill(zero(T), 7)
2725
src = Typ(x)
28-
GPUArrays.allowslow(true)
2926
for i = 1:7
3027
src[i] = i
3128
end
@@ -34,7 +31,6 @@ function test_indexing(Typ)
3431
@test Array(src[1:3]) == T[77, 22, 11]
3532
src[1] = T(0)
3633
src[2:end] = T(77)
37-
GPUArrays.allowslow(false)
3834
@test Array(src) == T[0, 77, 77, 77, 77, 77, 77]
3935
end
4036
end
@@ -43,18 +39,15 @@ function test_indexing(Typ)
4339
@testset "issue #42 with $T" begin
4440
Ac = rand(Float32, 2, 2)
4541
A = Typ(Ac)
46-
GPUArrays.allowslow(true)
4742
@test A[1] == Ac[1]
4843
@test A[end] == Ac[end]
4944
@test A[1, 1] == Ac[1, 1]
50-
GPUArrays.allowslow(false)
5145
end
5246
end
5347
for T in (Float32, Int32)
5448
@testset "Colon() $T" begin
5549
Ac = rand(T, 10)
5650
A = Typ(Ac)
57-
GPUArrays.allowslow(false)
5851
A[:] = T(1)
5952
@test all(x-> x == 1, A)
6053
A[:] = Typ(Ac)

0 commit comments

Comments
 (0)