Skip to content

Commit 3f574ba

Browse files
committed
Further reduce imports.
1 parent a50ae63 commit 3f574ba

File tree

6 files changed

+22
-41
lines changed

6 files changed

+22
-41
lines changed

src/GPUArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using LinearAlgebra.BLAS
1313
using Base.Cartesian
1414

1515
using FFTW
16-
import FFTW: *, plan_ifft!, plan_fft!, plan_fft, plan_ifft, size, plan_bfft, plan_bfft!
1716

1817
include("abstractarray.jl")
1918
include("abstract_gpu_interface.jl")

src/abstractarray.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import Base: similar, convert, _reshape, map!, copyto!, map, copy, deepcopy
2-
31
# Dense GPU Array
42
abstract type GPUArray{T, N} <: DenseArray{T, N} end
53

@@ -172,8 +170,8 @@ function Base.copyto!(
172170
dest
173171
end
174172

175-
copy(x::GPUArray) = identity.(x)
176-
deepcopy(x::GPUArray) = copy(x)
173+
Base.copy(x::GPUArray) = identity.(x)
174+
Base.deepcopy(x::GPUArray) = copy(x)
177175

178176
#=
179177
reinterpret taken from julia base/array.jl
@@ -222,13 +220,13 @@ function reinterpret(::Type{T}, a::GPUArray{S}, dims::NTuple{N, Integer}) where
222220
unsafe_reinterpret(T, a, dims)
223221
end
224222

225-
function _reshape(A::GPUArray{T}, dims::Dims) where T
223+
function Base._reshape(A::GPUArray{T}, dims::Dims) where T
226224
n = length(A)
227225
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
228226
return unsafe_reinterpret(T, A, dims)
229227
end
230228
#ambig
231-
function _reshape(A::GPUArray{T, 1}, dims::Tuple{Integer}) where T
229+
function Base._reshape(A::GPUArray{T, 1}, dims::Tuple{Integer}) where T
232230
n = Base._length(A)
233231
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
234232
return unsafe_reinterpret(T, A, dims)

src/array.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ function AbstractDeviceArray(ptr::Array, shape::Vararg{Integer, N}) where N
120120
reshape(ptr, shape)
121121
end
122122

123-
124123
function _gpu_call(f, A::JLArray, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
125124
blocks, threads = blocks_threads
126125
idx = ntuple(i-> 1, length(blocks))
@@ -170,32 +169,21 @@ end
170169
blas_module(::JLArray) = LinearAlgebra.BLAS
171170
blasbuffer(A::JLArray) = A.data
172171

173-
# defining our own plan type is the easiest way to pass around the plans in Base interface
172+
# defining our own plan type is the easiest way to pass around the plans in FFTW interface
174173
# without ambiguities
175174

176175
struct FFTPlan{T}
177176
p::T
178177
end
179-
function plan_fft(A::JLArray; kw_args...)
180-
FFTPlan(plan_fft(A.data; kw_args...))
181-
end
182-
function plan_fft!(A::JLArray; kw_args...)
183-
FFTPlan(plan_fft!(A.data; kw_args...))
184-
end
185-
function plan_bfft!(A::JLArray; kw_args...)
186-
FFTPlan(plan_bfft!(A.data; kw_args...))
187-
end
188-
function plan_bfft(A::JLArray; kw_args...)
189-
FFTPlan(plan_bfft(A.data; kw_args...))
190-
end
191-
function plan_ifft!(A::JLArray; kw_args...)
192-
FFTPlan(plan_ifft!(A.data; kw_args...))
193-
end
194-
function plan_ifft(A::JLArray; kw_args...)
195-
FFTPlan(plan_ifft(A.data; kw_args...))
196-
end
197178

198-
function *(plan::FFTPlan, A::JLArray)
179+
FFTW.plan_fft(A::JLArray; kw_args...) = FFTPlan(plan_fft(A.data; kw_args...))
180+
FFTW.plan_fft!(A::JLArray; kw_args...) = FFTPlan(plan_fft!(A.data; kw_args...))
181+
FFTW.plan_bfft!(A::JLArray; kw_args...) = FFTPlan(plan_bfft!(A.data; kw_args...))
182+
FFTW.plan_bfft(A::JLArray; kw_args...) = FFTPlan(plan_bfft(A.data; kw_args...))
183+
FFTW.plan_ifft!(A::JLArray; kw_args...) = FFTPlan(plan_ifft!(A.data; kw_args...))
184+
FFTW.plan_ifft(A::JLArray; kw_args...) = FFTPlan(plan_ifft(A.data; kw_args...))
185+
186+
function Base.:(*)(plan::FFTPlan, A::JLArray)
199187
x = plan.p * A.data
200188
JLArray(x)
201189
end

src/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
allequal(x) = true
22
allequal(x, y, z...) = x == y && allequal(y, z...)
3-
function map!(f, y::GPUArray, xs::GPUArray...)
3+
function Base.map!(f, y::GPUArray, xs::GPUArray...)
44
@assert allequal(size.((y, xs...))...)
55
return y .= f.(xs...)
66
end
7-
function map(f, y::GPUArray, xs::GPUArray...)
7+
function Base.map(f, y::GPUArray, xs::GPUArray...)
88
@assert allequal(size.((y, xs...))...)
99
return f.(y, xs...)
1010
end

src/construction.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
1-
import Base: fill!, zeros, ones, fill
2-
3-
4-
5-
function fill(X::Type{<: GPUArray}, val::T, dims::NTuple{N, Integer}) where {T, N}
1+
function Base.fill(X::Type{<: GPUArray}, val::T, dims::NTuple{N, Integer}) where {T, N}
62
res = similar(X, T, dims)
73
fill!(res, val)
84
end
9-
function fill(X::Type{<: GPUArray{T}}, val, dims::NTuple{N, Integer}) where {T, N}
5+
function Base.fill(X::Type{<: GPUArray{T}}, val, dims::NTuple{N, Integer}) where {T, N}
106
res = similar(X, T, dims)
117
fill!(res, convert(T, val))
128
end
13-
function fill!(A::GPUArray{T}, x) where T
9+
function Base.fill!(A::GPUArray{T}, x) where T
1410
gpu_call(A, (A, convert(T, x))) do state, a, val
1511
idx = @linearidx(a, state)
1612
@inbounds a[idx] = val
@@ -19,8 +15,8 @@ function fill!(A::GPUArray{T}, x) where T
1915
A
2016
end
2117

22-
zeros(T::Type{<: GPUArray}, dims::NTuple{N, Integer}) where N = fill(T, zero(eltype(T)), dims)
23-
ones(T::Type{<: GPUArray}, dims::NTuple{N, Integer}) where N = fill(T, one(eltype(T)), dims)
18+
Base.zeros(T::Type{<: GPUArray}, dims::NTuple{N, Integer}) where N = fill(T, zero(eltype(T)), dims)
19+
Base.ones(T::Type{<: GPUArray}, dims::NTuple{N, Integer}) where N = fill(T, one(eltype(T)), dims)
2420

2521
function uniformscaling_kernel(state, res::AbstractArray{T}, stride, s::UniformScaling) where T
2622
i = linear_index(state)

src/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ function LinearAlgebra.permutedims!(dest::GPUArray, src::GPUArray, perm::NTuple{
9090
end
9191

9292

93-
function copyto!(A::AbstractArray, B::Adjoint{T, <: GPUArray}) where T
93+
function Base.copyto!(A::AbstractArray, B::Adjoint{T, <: GPUArray}) where T
9494
copyto!(A, Adjoint(Array(B.parent)))
9595
end
96-
function copyto!(A::GPUArray, B::Adjoint{T, <: GPUArray}) where T
96+
function Base.copyto!(A::GPUArray, B::Adjoint{T, <: GPUArray}) where T
9797
transpose!(A, B.parent)
9898
end

0 commit comments

Comments
 (0)