Skip to content

Commit 6028c0e

Browse files
committed
Add DualPlan
1 parent 1695c7e commit 6028c0e

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

ext/AbstractFFTsForwardDiffExt.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11
module AbstractFFTsForwardDiffExt
22

33
using AbstractFFTs
4+
using AbstractFFTs.LinearAlgebra
45
import ForwardDiff
56
import ForwardDiff: Dual
6-
import AbstractFFTs: Plan
7+
import AbstractFFTs: Plan, mul!
78

8-
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
9-
@eval begin
10-
AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
11-
AbstractFFTs.plan_mul(p::AbstractFFTs.$P, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
12-
end
13-
end
14-
15-
mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x)
169

1710
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
1811
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im
@@ -26,10 +19,37 @@ array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, rea
2619
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))
2720

2821

22+
########
23+
# DualPlan
24+
# represents a plan acting on dual numbers. We wrap a plan acting on a higher dimensional tensor
25+
# as an array of duals can be reinterpreted as a higher dimensional array.
26+
# This allows standard FFTW plans to act on arrays of duals.
27+
#####
28+
struct DualPlan{T,P} <: Plan{T}
29+
p::P
30+
DualPlan{T,P}(p) where {T,P} = new(p)
31+
end
32+
33+
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p)
34+
DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p)
35+
Base.size(p::DualPlan) = Base.tail(size(p.p))
36+
Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))
37+
Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x))
38+
39+
function LinearAlgebra.mul!(y::AbstractArray{<:Dual}, p::DualPlan, x::AbstractArray{<:Dual})
40+
LinearAlgebra.mul!(dual2array(y), p.p, dual2array(x)) # even though `Dual` are immutable, when in an `Array` they can be modified.
41+
y
42+
end
43+
44+
function LinearAlgebra.mul!(y::AbstractArray{<:Complex{<:Dual}}, p::DualPlan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}})
45+
copyto!(y, p*x) # Complex duals cannot be reinterpret in-place
46+
end
47+
48+
2949
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
3050
@eval begin
31-
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims)
32-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims)
51+
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
52+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = DualPlan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
3353
end
3454
end
3555

test/abstractfftsforwarddiff.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using ForwardDiff
33
using Test
44
using ForwardDiff: Dual, partials, value
55

6+
# Needed until https://github.com/JuliaDiff/ForwardDiff.jl/pull/732 is merged
7+
complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k)
8+
69
@testset "ForwardDiff extension tests" begin
710
x1 = Dual.(1:4.0, 2:5, 3:6)
811

@@ -13,42 +16,42 @@ using ForwardDiff: Dual, partials, value
1316

1417
@testset "$f" for f in (fft, ifft, rfft, bfft)
1518
@test value.(f(x1)) == f(value.(x1))
16-
@test partials.(real(f(x1)), 1) + im*partials.(imag(f(x1)), 1) == f(partials.(x1, 1))
17-
@test partials.(real(f(x1)), 2) + im*partials.(imag(f(x1)), 2) == f(partials.(x1, 2))
19+
@test complexpartials.(f(x1), 1) == f(partials.(x1, 1))
20+
@test complexpartials.(f(x1), 2) == f(partials.(x1, 2))
1821
end
1922

2023
@test ifft(fft(x1)) x1
2124
@test irfft(rfft(x1), length(x1)) x1
2225
@test brfft(rfft(x1), length(x1)) 4x1
2326

2427
f = x -> real(fft([x; 0; 0])[1])
25-
@test derivative(f,0.1) 1
28+
@test ForwardDiff.derivative(f,0.1) 1
2629

2730
r = x -> real(rfft([x; 0; 0])[1])
28-
@test derivative(r,0.1) 1
31+
@test ForwardDiff.derivative(r,0.1) 1
2932

3033

3134
n = 100
3235
θ = range(0,2π; length=n+1)[1:end-1]
3336
# emperical from Mathematical
34-
@test derivative-> fft(exp.(ω .* cos.(θ)))[1]/n, 1) 0.565159103992485
37+
@test ForwardDiff.derivative-> fft(exp.(ω .* cos.(θ)))[1]/n, 1) 0.565159103992485
3538

3639
# c = x -> dct([x; 0; 0])[1]
3740
# @test derivative(c,0.1) ≈ 1
3841

3942
@testset "matrix" begin
4043
A = x1 * (1:10)'
4144
@test value.(fft(A)) == fft(value.(A))
42-
@test partials.(fft(A), 1) == fft(partials.(A, 1))
43-
@test partials.(fft(A), 2) == fft(partials.(A, 2))
45+
@test complexpartials.(fft(A), 1) == fft(partials.(A, 1))
46+
@test complexpartials.(fft(A), 2) == fft(partials.(A, 2))
4447

4548
@test value.(fft(A, 1)) == fft(value.(A), 1)
46-
@test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1)
47-
@test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1)
49+
@test complexpartials.(fft(A, 1), 1) == fft(partials.(A, 1), 1)
50+
@test complexpartials.(fft(A, 1), 2) == fft(partials.(A, 2), 1)
4851

4952
@test value.(fft(A, 2)) == fft(value.(A), 2)
50-
@test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
51-
@test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
53+
@test complexpartials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
54+
@test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
5255
end
5356

5457
c1 = complex.(x1)

0 commit comments

Comments
 (0)