Skip to content

Commit 2b60a83

Browse files
committed
Add ForwardDiff extension
1 parent 70524d2 commit 2b60a83

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.5.0"
3+
version = "1.6.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -9,10 +9,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99

1010
[weakdeps]
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1213
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1314

1415
[extensions]
1516
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
17+
AbstractFFTsForwardDiffExt = "ForwardDiff"
1618
AbstractFFTsTestExt = "Test"
1719

1820
[compat]

ext/AbstractFFTsForwardDiffExt.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module AbstractFFTsForwardDiffExt
2+
3+
using AbstractFFTs
4+
import ForwardDiff
5+
import ForwardDiff: Dual
6+
import AbstractFFTs: Plan
7+
8+
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
9+
@eval begin
10+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
11+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
12+
end
13+
end
14+
15+
mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x)
16+
17+
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
18+
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im
19+
20+
AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
21+
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)
22+
23+
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
24+
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
25+
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
26+
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))
27+
28+
29+
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
30+
@eval begin
31+
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims)
32+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = AbstractFFTs.$plan(dual2array(x), 1 .+ dims)
33+
end
34+
end
35+
36+
37+
38+
end # module

test/abstractfftsforwarddiff.jl

Whitespace-only changes.

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,4 @@ end
274274
end
275275
end
276276

277+
include("abstractfftsforwarddiff.jl")

0 commit comments

Comments
 (0)