Skip to content

Commit ca079ee

Browse files
authored
Merge pull request #329 from JuliaGPU/tb/compat
Simplify compatibility with GPUArrays
2 parents fa6e0b3 + 7374046 commit ca079ee

File tree

7 files changed

+4
-67
lines changed

7 files changed

+4
-67
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ Adapt = "2.0"
1616
julia = "1.5"
1717

1818
[extras]
19-
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
2019
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2120
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2221

2322
[targets]
24-
test = ["Test", "FFTW", "FillArrays"]
23+
test = ["Test", "FillArrays"]

docs/src/testsuite.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ TestSuite.test_base(T) # basic functionality like launching a kernel on the GPU
4444
TestSuite.test_blas(T) # tests the blas interface
4545
TestSuite.test_broadcasting(T) # tests the broadcasting implementation
4646
TestSuite.test_construction(T) # tests all kinds of different ways of constructing the array
47-
TestSuite.test_fft(T) # fft tests
4847
TestSuite.test_linalg(T) # linalg function tests
4948
TestSuite.test_mapreduce(T) # mapreduce sum, etc
5049
TestSuite.test_indexing(T) # indexing tests

src/host/linalg.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,3 @@ function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArra
211211
end
212212
return dest
213213
end
214-
215-
216-
## inv for Triangular
217-
for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
218-
@eval function Base.inv(x::$TR{<:Any,<:AbstractGPUArray})
219-
out = typeof(parent(x))(I(size(x,1)))
220-
$TR(LinearAlgebra.ldiv!(x,out))
221-
end
222-
end

src/reference.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -329,30 +329,6 @@ Base.copyto!(dest::DenseJLArray{T}, source::DenseJLArray{T}) where {T} =
329329
copyto!(dest, 1, source, 1, length(source))
330330

331331

332-
## fft
333-
334-
using AbstractFFTs
335-
336-
# defining our own plan type is the easiest way to pass around the plans in FFTW interface
337-
# without ambiguities
338-
339-
struct FFTPlan{T}
340-
p::T
341-
end
342-
343-
AbstractFFTs.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
344-
AbstractFFTs.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
345-
AbstractFFTs.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
346-
AbstractFFTs.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
347-
AbstractFFTs.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
348-
AbstractFFTs.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))
349-
350-
function Base.:(*)(plan::FFTPlan, A::JLArray)
351-
x = plan.p * A.data
352-
JLArray(x)
353-
end
354-
355-
356332
## Random
357333

358334
using Random
@@ -389,14 +365,4 @@ function GPUArrays.default_rng(::Type{<:JLArray})
389365
GLOBAL_RNG[]
390366
end
391367

392-
393-
## LinearAlgebra
394-
395-
using LinearAlgebra
396-
397-
for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
398-
@eval LinearAlgebra.ldiv!(x::$TR{T,<:JLArray{T,2}}, y::JLArray{T,2}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.ldiv!($TR(parent(x).data),y.data))
399-
@eval LinearAlgebra.rdiv!(x::JLArray{T,2}, y::$TR{T,<:JLArray{T,2}}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} = JLArray(LinearAlgebra.rdiv!(x.data,$TR(parent(y).data)))
400-
end
401-
402368
end

test/testsuite.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ using LinearAlgebra
1212
using Random
1313
using Test
1414

15-
using FFTW
1615
using Adapt
1716
using FillArrays
1817

@@ -72,7 +71,6 @@ include("testsuite/mapreduce.jl")
7271
include("testsuite/broadcasting.jl")
7372
include("testsuite/linalg.jl")
7473
include("testsuite/math.jl")
75-
include("testsuite/fft.jl")
7674
include("testsuite/random.jl")
7775
include("testsuite/uniformscaling.jl")
7876

test/testsuite/fft.jl

Lines changed: 0 additions & 12 deletions
This file was deleted.

test/testsuite/linalg.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,17 @@
3030

3131
rand!(gpu_a)
3232
gpu_c = copyto!(gpu_b, TR(gpu_a))
33-
@test all(gpu_b .== TR(gpu_a))
33+
@test all(Array(gpu_b) .== TR(Array(gpu_a)))
3434
@test gpu_c isa AT
3535
end
3636

37-
@testset "inv for triangular" for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
38-
@test compare(x -> inv(TR(x)), AT, rand(Float32, 32, 32))
39-
end
40-
4137
for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
4238
gpu_a = AT{Float32}(undef, 128, 128) |> rand! |> TR
4339
gpu_b = AT{Float32}(undef, 128, 128) |> TR
4440

4541
gpu_c = copyto!(gpu_b, gpu_a)
46-
@test all(gpu_b .== gpu_a)
47-
@test all(gpu_c .== gpu_a)
42+
@test all(Array(gpu_b) .== Array(gpu_a))
43+
@test all(Array(gpu_c) .== Array(gpu_a))
4844
@test gpu_c isa TR
4945
end
5046
end

0 commit comments

Comments
 (0)