Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.5.0"
version = "1.6.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsForwardDiffExt = "ForwardDiff"
AbstractFFTsTestExt = "Test"

[compat]
Aqua = "0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
ForwardDiff = "0.10"
LinearAlgebra = "<0.0.1, 1"
Random = "<0.0.1, 1"
Test = "<0.0.1, 1"
Expand All @@ -31,9 +35,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "ForwardDiff", "Random", "Test", "Unitful"]
66 changes: 66 additions & 0 deletions ext/AbstractFFTsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
module AbstractFFTsForwardDiffExt

using AbstractFFTs
using AbstractFFTs.LinearAlgebra
import ForwardDiff
import ForwardDiff: Dual
import AbstractFFTs: Plan, mul!, dualplan, dual2array


AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im

AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)

dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))


########
# DualPlan
# represents a plan acting on dual numbers. We wrap a plan acting on a higher dimensional tensor
# as an array of duals can be reinterpreted as a higher dimensional array.
# This allows standard FFTW plans to act on arrays of duals.
#####
struct DualPlan{T,P} <: Plan{T}
p::P
DualPlan{T,P}(p) where {T,P} = new(p)
end

DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p)
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p)
dualplan(D, p) = DualPlan(D, p)
Base.size(p::DualPlan) = Base.tail(size(p.p))

Check warning on line 36 in ext/AbstractFFTsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsForwardDiffExt.jl#L36

Added line #L36 was not covered by tests
Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))
Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))

function LinearAlgebra.mul!(y::AbstractArray{<:Dual}, p::DualPlan, x::AbstractArray{<:Dual})
LinearAlgebra.mul!(dual2array(y), p.p, dual2array(x)) # even though `Dual` are immutable, when in an `Array` they can be modified.
y

Check warning on line 42 in ext/AbstractFFTsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsForwardDiffExt.jl#L40-L42

Added lines #L40 - L42 were not covered by tests
end

function LinearAlgebra.mul!(y::AbstractArray{<:Complex{<:Dual}}, p::DualPlan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}})
copyto!(y, p*x) # Complex duals cannot be reinterpret in-place
end


for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
@eval begin
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
end
end


for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
@eval begin
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))

Check warning on line 60 in ext/AbstractFFTsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsForwardDiffExt.jl#L60

Added line #L60 was not covered by tests
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims))
end
end


end # module
6 changes: 6 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
include("definitions.jl")
include("TestUtils.jl")

# Create function used by multiple extension as loading order is not guaranteed
function dualplan end
function dual2array end

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
include("../ext/AbstractFFTsTestExt.jl")
include("../ext/AbstractFFTsForwardDiffExt.jl")
end


end # module
60 changes: 60 additions & 0 deletions test/abstractfftsforwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using AbstractFFTs
using ForwardDiff
using Test
using ForwardDiff: Dual, partials, value

# Needed until https://github.com/JuliaDiff/ForwardDiff.jl/pull/732 is merged
complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k)

@testset "ForwardDiff extension tests" begin
x1 = Dual.(1:4.0, 2:5, 3:6)

@test AbstractFFTs.complexfloat(x1)[1] === AbstractFFTs.complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
@test AbstractFFTs.realfloat(x1)[1] === AbstractFFTs.realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)

@test fft(x1, 1)[1] isa Complex{<:Dual}

@testset "$f" for f in (fft, ifft, rfft, bfft)
@test value.(f(x1)) == f(value.(x1))
@test complexpartials.(f(x1), 1) == f(partials.(x1, 1))
@test complexpartials.(f(x1), 2) == f(partials.(x1, 2))
end

@test ifft(fft(x1)) ≈ x1
@test irfft(rfft(x1), length(x1)) ≈ x1
@test brfft(rfft(x1), length(x1)) ≈ 4x1

f = x -> real(fft([x; 0; 0])[1])
@test ForwardDiff.derivative(f,0.1) ≈ 1

r = x -> real(rfft([x; 0; 0])[1])
@test ForwardDiff.derivative(r,0.1) ≈ 1


n = 100
θ = range(0,2π; length=n+1)[1:end-1]
# emperical from Mathematical
@test ForwardDiff.derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485

# c = x -> dct([x; 0; 0])[1]
# @test derivative(c,0.1) ≈ 1

@testset "matrix" begin
A = x1 * (1:10)'
@test value.(fft(A)) == fft(value.(A))
@test complexpartials.(fft(A), 1) == fft(partials.(A, 1))
@test complexpartials.(fft(A), 2) == fft(partials.(A, 2))

@test value.(fft(A, 1)) == fft(value.(A), 1)
@test complexpartials.(fft(A, 1), 1) == fft(partials.(A, 1), 1)
@test complexpartials.(fft(A, 1), 2) == fft(partials.(A, 2), 1)

@test value.(fft(A, 2)) == fft(value.(A), 2)
@test complexpartials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
@test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
end

c1 = complex.(x1)
@test mul!(similar(c1), plan_fft(x1), x1) == fft(x1)
@test mul!(similar(c1), plan_fft(c1), c1) == fft(c1)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,4 @@ end
end
end

include("abstractfftsforwarddiff.jl")
Loading