Skip to content

Commit a89fa3c

Browse files
authored
Merge pull request #237 from JuliaGPU/tb/from_cuarrays_with_love
Port functionality from CuArrays
2 parents 7c56448 + 8eb42a8 commit a89fa3c

File tree

8 files changed

+97
-63
lines changed

8 files changed

+97
-63
lines changed

src/host/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,13 @@ function Base.repeat(a::AbstractGPUVector, m::Int)
5353
end
5454
return b
5555
end
56+
57+
## PermutedDimsArrays
58+
59+
using Base: PermutedDimsArrays
60+
61+
# PermutedDimsArrays' custom copyto! doesn't know how to deal with GPU arrays
62+
function PermutedDimsArrays._copy!(dest::PermutedDimsArray{T,N,<:Any,<:Any,<:AbstractGPUArray}, src) where {T,N}
    # Fill the GPU-backed permuted wrapper through a broadcast assignment
    # rather than the generic `_copy!` implementation.
    dest .= src
    # Hand the destination back to the caller, as the mutating copy does.
    return dest
end

src/host/indexing.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,72 @@ export allowscalar, @allowscalar, assertscalar
55

66
# mechanism to disallow scalar operations
77

8-
const scalar_allowed = Ref(true)
8+
@enum ScalarIndexing ScalarAllowed ScalarWarned ScalarDisallowed
9+
10+
const scalar_allowed = Ref(ScalarWarned)
911
const scalar_warned = Ref(false)
1012

11-
function allowscalar(flag = true)
12-
scalar_allowed[] = flag
13+
"""
14+
allowscalar(allow=true, warn=true)
15+
16+
Configure whether scalar indexing is allowed depending on the value of `allow`.
17+
18+
If allowed, `warn` can be set to emit a single warning on the first scalar operation. Calling this function will
19+
reset the state of the warning, and throw a new warning on subsequent scalar iteration.
20+
"""
21+
function allowscalar(allow::Bool=true, warn::Bool=true)
    # Reconfiguring always re-arms the one-time warning flag.
    scalar_warned[] = false

    # Map the two flags onto the three-state enum; `warn` only matters when
    # scalar indexing is allowed at all.
    mode = if !allow
        ScalarDisallowed
    elseif warn
        ScalarWarned
    else
        ScalarAllowed
    end
    scalar_allowed[] = mode

    return nothing
end
1632

33+
"""
34+
assertscalar(op::String)
35+
36+
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
37+
error will be thrown ([`allowscalar`](@ref)).
38+
"""
1739
function assertscalar(op = "operation")
18-
if !scalar_allowed[]
40+
if scalar_allowed[] == ScalarDisallowed
1941
error("$op is disallowed")
20-
elseif !scalar_warned[]
42+
elseif scalar_allowed[] == ScalarWarned && !scalar_warned[]
2143
@warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
2244
scalar_warned[] = true
2345
end
2446
return
2547
end
2648

49+
"""
50+
@allowscalar ex...
51+
@disallowscalar ex...
52+
53+
Temporarily allow or disallow scalar iteration.
54+
55+
Note that these macros are intended for code that is known and allowed to use
56+
scalar iteration (or not), i.e., there is no option to throw a warning. Only use this on
57+
fine-grained expressions.
58+
"""
2759
macro allowscalar(ex)
2860
quote
2961
local prev = scalar_allowed[]
30-
scalar_allowed[] = true
62+
scalar_allowed[] = ScalarAllowed
3163
local ret = $(esc(ex))
3264
scalar_allowed[] = prev
3365
ret
3466
end
3567
end
3668

69+
@doc (@doc @allowscalar) ->
3770
macro disallowscalar(ex)
3871
quote
3972
local prev = scalar_allowed[]
40-
scalar_allowed[] = false
73+
scalar_allowed[] = ScalarDisallowed
4174
local ret = $(esc(ex))
4275
scalar_allowed[] = prev
4376
ret

src/host/linalg.jl

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
# integration with LinearAlgebra stdlib
22

3-
function LinearAlgebra.transpose!(At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
3+
## transpose and adjoint
4+
5+
function transpose_f!(f, At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
46
gpu_call(At, A) do ctx, At, A
57
idx = @cartesianidx A ctx
6-
@inbounds At[idx[2], idx[1]] = A[idx[1], idx[2]]
8+
@inbounds At[idx[2], idx[1]] = f(A[idx[1], idx[2]])
79
return
810
end
911
At
1012
end
1113

12-
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
13-
CartesianIndex(ntuple(d-> (@inbounds return I[perm[d]]), Val(N)))
14+
LinearAlgebra.transpose!(At::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(transpose, At, A)
15+
LinearAlgebra.adjoint!(At::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(adjoint, At, A)
16+
17+
function Base.copyto!(A::AbstractGPUArray, B::Adjoint{T, <: AbstractGPUArray}) where T
18+
adjoint!(A, B.parent)
1419
end
1520

16-
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) where N
17-
perm isa Tuple || (perm = Tuple(perm))
18-
gpu_call(dest, src, perm) do ctx, dest, src, perm
19-
I = @cartesianidx src ctx
20-
@inbounds dest[genperm(I, perm)] = src[I]
21-
return
22-
end
23-
return dest
21+
function Base.copyto!(A::AbstractGPUArray, B::Transpose{T, <: AbstractGPUArray}) where T
22+
transpose!(A, B.parent)
2423
end
2524

2625
function Base.copyto!(A::AbstractArray, B::Adjoint{<:Any, <:AbstractGPUArray})
@@ -29,17 +28,17 @@ end
2928
function Base.copyto!(A::AbstractArray, B::Transpose{<:Any, <:AbstractGPUArray})
3029
copyto!(A, Transpose(Array(parent(B))))
3130
end
31+
32+
33+
## triangular
34+
3235
function Base.copyto!(A::AbstractArray, B::UpperTriangular{<:Any, <:AbstractGPUArray})
3336
copyto!(A, UpperTriangular(Array(parent(B))))
3437
end
3538
function Base.copyto!(A::AbstractArray, B::LowerTriangular{<:Any, <:AbstractGPUArray})
3639
copyto!(A, LowerTriangular(Array(parent(B))))
3740
end
3841

39-
function Base.copyto!(A::AbstractGPUArray, B::Adjoint{T, <: AbstractGPUArray}) where T
40-
transpose!(A, B.parent)
41-
end
42-
4342
function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
4443
gpu_call(A, d) do ctx, _A, _d
4544
I = @cartesianidx _A
@@ -64,17 +63,8 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
6463
return A
6564
end
6665

67-
function LinearAlgebra.copy_transpose!(dst::AbstractGPUArray, src::AbstractGPUArray)
68-
gpu_call(st, src) do ctx, dst, src
69-
I = @cartesianidx dst
70-
dst[I...] = src[reverse(I)...]
71-
return
72-
end
73-
return dst
74-
end
7566

76-
77-
# matrix multiplication
67+
## matrix multiplication
7868

7969
function generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S,R}
8070
if size(A,2) != size(B,1)
@@ -137,3 +127,20 @@ function generic_lmul!(s::Number, X::AbstractGPUArray)
137127
end
138128

139129
LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
130+
131+
132+
## permutedims
133+
134+
"""
    genperm(I::CartesianIndex{N}, perm::NTuple{N})

Return a new `CartesianIndex` whose `d`-th component is `I[perm[d]]`, i.e. the
components of `I` reordered according to `perm`.
"""
function genperm(I::CartesianIndex{N}, perm::NTuple{N}) where N
    # Build the permuted components as a statically sized tuple.
    permuted = ntuple(Val(N)) do d
        @inbounds I[perm[d]]
    end
    return CartesianIndex(permuted)
end
137+
138+
"""
    permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm)

Permute the dimensions of `src` into `dest`: for every Cartesian index `I` of
`src`, the element `src[I]` is stored at `dest[genperm(I, perm)]`.

`perm` may be a tuple or any collection accepted by `Tuple(perm)` (for example
a `Vector`). Returns `dest`.
"""
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm)
    # Normalize to a tuple so it matches `genperm`'s `NTuple` signature.
    perm isa Tuple || (perm = Tuple(perm))
    gpu_call(dest, src, perm) do ctx, dest, src, perm
        I = @cartesianidx src ctx
        @inbounds dest[genperm(I, perm)] = src[I]
        return
    end
    return dest
end

src/host/mapreduce.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,3 @@ function acc_mapreduce(f, op, v0::OT, A::GPUSrcArray, rest...) where {OT}
178178
target=out, threads=threads, blocks=blocks)
179179
reduce(op, Array(out))
180180
end
181-
182-
"""
183-
Same as Base.isapprox, but without keyword args and without nans
184-
"""
185-
function fast_isapprox(x::Number, y::Number, rtol::Real = Base.rtoldefault(x, y), atol::Real=0)
186-
x == y || (isfinite(x) && isfinite(y) && abs(x-y) <= max(atol, rtol*max(abs(x), abs(y))))
187-
end
188-
189-
Base.isapprox(A::AbstractGPUArray{T1}, B::AbstractGPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, B, T1(rtol)|>real, T1(atol)|>real))
190-
Base.isapprox(A::AbstractArray{T1}, B::AbstractGPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, Array(B), T1(rtol)|>real, T1(atol)|>real))
191-
Base.isapprox(A::AbstractGPUArray{T1}, B::AbstractArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(Array(A), B, T1(rtol)|>real, T1(atol)|>real))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ include("testsuite.jl")
44

55
@testset "JLArray" begin
66
using GPUArrays.JLArrays
7+
JLArrays.allowscalar(false)
78
TestSuite.test(JLArray)
89
end

test/testsuite/base.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,12 @@ function test_base(AT)
148148
@test blocks == 1
149149
@test threads == 1
150150
end
151+
152+
@testset "permutedims" begin
153+
@test compare(x->permutedims(x, [1, 2]), AT, rand(4, 4))
154+
155+
inds = rand(1:100, 150, 150)
156+
@test compare(x->permutedims(view(x, inds, :), (3, 2, 1)), AT, rand(100, 100))
157+
end
151158
end
152159
end

test/testsuite/linalg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
function test_linalg(AT)
22
@testset "linear algebra" begin
3-
@testset "transpose" begin
3+
@testset "adjoint and transpose" begin
44
@test compare(adjoint, AT, rand(Float32, 32, 32))
5+
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
56
@test compare(transpose, AT, rand(Float32, 32, 32))
7+
@test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
8+
@test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
9+
@test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
610
@test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))
711
@test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))
812
end
@@ -19,6 +23,7 @@ function test_linalg(AT)
1923
copyto!(ga, LowerTriangular(gb))
2024
@test ga == Array(collect(LowerTriangular(gb)))
2125
end
26+
2227
@testset "permutedims" begin
2328
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
2429
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))

test/testsuite/mapreduce.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,5 @@ function test_mapreduce(AT)
5858
@test A !== deepcopy(A)
5959
end
6060
end
61-
62-
@testset "isapprox" begin
63-
for ET in supported_eltypes()
64-
ET <: Complex && continue
65-
A = fill(AT{ET}, ET(0), (100,))
66-
B = ones(AT{ET}, 100)
67-
@test !(A ≈ B)
68-
@test !(A ≈ Array(B))
69-
@test !(Array(A) ≈ B)
70-
71-
72-
ca = AT(randn(ComplexF64,3,3))
73-
cb = copy(ca)
74-
cb[1:1, 1:1] .+= 1e-7im
75-
@test isapprox(ca, cb, atol=1e-5)
76-
@test !isapprox(ca, cb, atol=1e-9)
77-
end
78-
end
7961
end
8062
end

0 commit comments

Comments
 (0)