Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.1.0"
version = "1.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
14 changes: 14 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

"""
region(p::Plan)

Return an iterable of the dimensions that are transformed by the FFT plan `p`.

# Implementation

The default definition of `region` returns `p.region`.
Hence this method should be implemented only for types of `Plan`s that do not store the transformed region in a field of name `region`.
"""
region(p::Plan) = p.region

fftfloat(x) = _fftfloat(float(x))
_fftfloat(::Type{T}) where {T<:BlasReal} = T
_fftfloat(::Type{Float16}) = Float32
Expand Down Expand Up @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)

region(p::ScaledPlan) = region(p.p)

show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))

Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,21 @@ end
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test AbstractFFTs.region(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft
P = plan_bfft(x, dims)
@test P * y ≈ fftw_bfft
@test P \ (P * y) ≈ y
@test AbstractFFTs.region(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft
P = plan_ifft(x, dims)
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test AbstractFFTs.region(P) == dims

# real FFT
fftw_rfft = fftw_fft[
Expand All @@ -84,18 +87,21 @@ end
@test eltype(P) === Int
@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test AbstractFFTs.region(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_brfft
@test P \ (P * ry) ≈ ry
@test AbstractFFTs.region(P) == dims

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test AbstractFFTs.region(P) == dims
end
end

Expand Down Expand Up @@ -170,7 +176,7 @@ end
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, region(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

Expand Down