diff --git a/src/fft/fft.jl b/src/fft/fft.jl index b0a187fe2..1dfe213bd 100644 --- a/src/fft/fft.jl +++ b/src/fft/fft.jl @@ -136,47 +136,63 @@ for (f, xtype, inplace, forward) in ( (:plan_fft, :rocfft_transform_type_complex_forward, :false, :true), (:plan_bfft, :rocfft_transform_type_complex_inverse, :false, :false), ) - @eval function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N} - _inplace = $(inplace) - _xtype = $(xtype) - R = length(region) - region = NTuple{R,Int}(region) - pp = get_plan(_xtype, size(X), T, _inplace, region) - return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T) + @eval begin + # Try to constant-propagate the `region` argument so that its length `R` can be inferred. + Base.@constprop :aggressive function $f(X::ROCArray{T, N}, region) where {T <: rocfftComplexes, N} + R = length(region) + region = NTuple{R,Int}(region) + return $f(X, region) + end + + function $f(X::ROCArray{T, N}, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R} + _inplace = $(inplace) + _xtype = $(xtype) + pp = get_plan(_xtype, size(X), T, _inplace, region) + return cROCFFTPlan{T,$forward,_inplace,N,R,Nothing}(pp..., X, size(X), _xtype, region, nothing, false, T) + end end end -function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N} - inplace = false - xtype = rocfft_transform_type_real_forward +Base.@constprop :aggressive function plan_rfft(X::ROCArray{T,N}, region) where {T<:rocfftReals,N} R = length(region) region = NTuple{R,Int}(region) + return plan_rfft(X, region) +end + +function plan_rfft(X::ROCArray{T,N}, region::NTuple{R,Int}) where {T<:rocfftReals,N,R} + inplace = false + xtype = rocfft_transform_type_real_forward pp = get_plan(xtype, size(X), T, inplace, region) - ydims = collect(size(X)) - ydims[region[1]] = div(ydims[region[1]],2) + 1 + + xdims = size(X) + ydims = Base.setindex(xdims, div(xdims[region[1]],2) + 1, region[1]) # The buffer is not needed for real-to-complex (`mul!`), # but it’s required for complex-to-real (`ldiv!`). - buffer = ROCArray{complex(T)}(undef, ydims...) + buffer = ROCArray{complex(T)}(undef, ydims) B = typeof(buffer) - return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, true, T) + return rROCFFTPlan{T,ROCFFT_FORWARD,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, true, T) end -function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N} - inplace = false - xtype = rocfft_transform_type_real_inverse +Base.@constprop :aggressive function plan_brfft(X::ROCArray{T,N}, d::Integer, region) where {T <: rocfftComplexes, N} R = length(region) region = NTuple{R,Int}(region) - ydims = collect(size(X)) - ydims[region[1]] = d - pp = get_plan(xtype, (ydims...,), T, inplace, region) + return plan_brfft(X, d, region) +end + +function plan_brfft(X::ROCArray{T,N}, d::Integer, region::NTuple{R,Int}) where {T <: rocfftComplexes, N, R} + inplace = false + xtype = rocfft_transform_type_real_inverse + xdims = size(X) + ydims = Base.setindex(xdims, d, region[1]) + pp = get_plan(xtype, ydims, T, inplace, region) # Buffer to not modify the input in a complex-to-real FFT. buffer = ROCArray{T}(undef, size(X)) B = typeof(buffer) - return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, (ydims...,), xtype, region, buffer, false, T) + return rROCFFTPlan{T,ROCFFT_INVERSE,inplace,N,R,B}(pp..., X, ydims, xtype, region, buffer, false, T) end # FIXME: plan_inv methods allocate needlessly (to provide type parameters and normalization function) diff --git a/test/rocarray/fft.jl b/test/rocarray/fft.jl index 323480220..113b9df7b 100644 --- a/test/rocarray/fft.jl +++ b/test/rocarray/fft.jl @@ -17,7 +17,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) dX = ROCArray(X) - p = plan_fft(dX) + p = @inferred plan_fft(dX) dY = p * dX @test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL) @test X ≈ collect(dX) @@ -37,7 +37,7 @@ function in_place(X::AbstractArray{T,N}) where {T <: Complex,N} fftw_X = fft(X) dX = ROCArray(X) - p = plan_fft!(dX) + p = @inferred plan_fft!(dX) p * dX @test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL) @@ -50,7 +50,7 @@ function batched(X::AbstractArray{T,N}, region) where {T <: Complex,N} fftw_X = fft(X, region) dX = ROCArray(X) - p = plan_fft!(dX, region) + p = @inferred plan_fft!(dX, region) p * dX @test isapprox(collect(dX), fftw_X; rtol=MYRTOL, atol=MYATOL) @@ -173,7 +173,7 @@ function out_of_place(X::AbstractArray{T,N}) where {T <: Real,N} fftw_X = rfft(X) dX = ROCArray(X) - p = plan_rfft(dX) + p = @inferred plan_rfft(dX) dY = p * dX Y = collect(dY) @test isapprox(Y, fftw_X; rtol=MYRTOL, atol=MYATOL) @@ -197,7 +197,7 @@ function batched(X::AbstractArray{T,N},region) where {T <: Real,N} fftw_X = rfft(X,region) dX = ROCArray(X) - p = plan_rfft(dX, region) + p = @inferred plan_rfft(dX, region) dY = p * dX @test isapprox(collect(dY), fftw_X; rtol=MYRTOL, atol=MYATOL) @test X ≈ collect(dX)