Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions src/fft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,47 +136,63 @@ for (f, xtype, inplace, forward) in (
(:plan_fft, :rocfft_transform_type_complex_forward, :false, :true),
(:plan_bfft, :rocfft_transform_type_complex_inverse, :false, :false),
)
@eval function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
_inplace = $(inplace)
_xtype = $(xtype)
R = length(region)
region = NTuple{R,Int}(region)
pp = get_plan(_xtype, size(X), T, _inplace, region)
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
@eval begin
# Try to constant-propagate the `region` argument so that its length `R` can be inferred.
Base.@constprop :aggressive function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N}
R = length(region)
region = NTuple{R,Int}(region)
return $f(X, region)
end

function $f(X::ROCArray{T, N}, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
_inplace = $(inplace)
_xtype = $(xtype)
pp = get_plan(_xtype, size(X), T, _inplace, region)
return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T)
end
end
end

function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
inplace = false
xtype = rocfft_transform_type_real_forward
Base.@constprop :aggressive function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N}
R = length(region)
region = NTuple{R,Int}(region)
return plan_rfft(X, region)
end

function plan_rfft(X::ROCArray{T,N}, region::NTuple{R,Int}) where {T<:rocfftReals,N,R}
inplace = false
xtype = rocfft_transform_type_real_forward
pp = get_plan(xtype, size(X), T, inplace, region)
ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2) + 1

xdims = size(X)
ydims = Base.setindex(xdims, div(xdims[region[1]],2) + 1, region[1])

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
buffer = ROCArray{complex(T)}(undef, ydims...)
buffer = ROCArray{complex(T)}(undef, ydims)
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, true, T)
return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, true, T)
end

function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
inplace = false
xtype = rocfft_transform_type_real_inverse
Base.@constprop :aggressive function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N}
R = length(region)
region = NTuple{R,Int}(region)
ydims = collect(size(X))
ydims[region[1]] = d
pp = get_plan(xtype, (ydims...,), T, inplace, region)
return plan_brfft(X, d, region)
end

function plan_brfft(X::ROCArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R}
inplace = false
xtype = rocfft_transform_type_real_inverse
xdims = size(X)
ydims = Base.setindex(xdims, d, region[1])
pp = get_plan(xtype, ydims, T, inplace, region)

# Buffer to not modify the input in a complex-to-real FFT.
buffer = ROCArray{T}(undef, size(X))
B = typeof(buffer)

return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, false, T)
return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, false, T)
end

# FIXME: plan_inv methods allocate needlessly (to provide type parameters and normalization function)
Expand Down
10 changes: 5 additions & 5 deletions test/rocarray/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)

dX = ROCArray(X)
p = plan_fft(dX)
p = @inferred plan_fft(dX)
dY = p * dX
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
@test X ≈ collect(dX)
Expand All @@ -37,7 +37,7 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N}
fftw_X = fft(X)

dX = ROCArray(X)
p = plan_fft!(dX)
p = @inferred plan_fft!(dX)
p * dX
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)

Expand All @@ -50,7 +50,7 @@ function batched(X::AbstractArray{T,N}, region) where {T <: Complex,N}
fftw_X = fft(X, region)

dX = ROCArray(X)
p = plan_fft!(dX, region)
p = @inferred plan_fft!(dX, region)
p * dX
@test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL)

Expand Down Expand Up @@ -173,7 +173,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N}
fftw_X = rfft(X)
dX = ROCArray(X)

p = plan_rfft(dX)
p = @inferred plan_rfft(dX)
dY = p * dX
Y = collect(dY)
@test isapprox(Y, fftw_X; rtol=MYRTOL, atol=MYATOL)
Expand All @@ -197,7 +197,7 @@ function batched(X::AbstractArray{T,N},region) where {T <: Real,N}
fftw_X = rfft(X,region)
dX = ROCArray(X)

p = plan_rfft(dX, region)
p = @inferred plan_rfft(dX, region)
dY = p * dX
@test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL)
@test X ≈ collect(dX)
Expand Down