Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
AbstractFFTs = "1.3.1"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be done?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bumping the compat version of a dependency to add a new min bound? That's perfectly fine as long as we're not suddenly breaking anything. We do it all the time for e.g. Flux -> NNlib.

ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
Expand Down
123 changes: 0 additions & 123 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,6 @@ AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims

# the adjoint Jacobian of an FFT with respect to its input is the backward FFT applied
# to the gradient with respect to its output, but with a different normalization factor
# Reverse rule for `fft(xs)`: the forward value is `AbstractFFTs.fft(xs)` and the
# incoming cotangent `Δ` is pulled back through the unnormalized `bfft`.
@adjoint function fft(xs)
  fft_pullback(Δ) = (AbstractFFTs.bfft(Δ),)
  return AbstractFFTs.fft(xs), fft_pullback
end

@adjoint function *(P::AbstractFFTs.Plan, xs)
return P * xs, function(Δ)
N = prod(size(xs)[[P.region...]])
Expand All @@ -691,123 +685,6 @@ end
end
end

# all of the plans normalize their inverse, while we need the unnormalized one.
# Reverse rule for `ifft(xs)`: the cotangent `Δ` is pulled back with a forward `fft`
# divided by the number of elements of `xs`.
@adjoint function ifft(xs)
  y = AbstractFFTs.ifft(xs)
  function ifft_pullback(Δ)
    scale = length(xs)
    return (AbstractFFTs.fft(Δ) / scale,)
  end
  return y, ifft_pullback
end

# Reverse rule for `bfft(xs)`: the cotangent `Δ` is pulled back with a plain forward
# `fft`, without any rescaling.
@adjoint function bfft(xs)
  bfft_pullback(Δ) = (AbstractFFTs.fft(Δ),)
  return AbstractFFTs.bfft(xs), bfft_pullback
end

# Reverse rule for `fftshift(x)`: the cotangent is pulled back with `ifftshift`.
@adjoint function fftshift(x)
  shifted = fftshift(x)
  return shifted, Δ -> (ifftshift(Δ),)
end

# Reverse rule for `ifftshift(x)`: the cotangent is pulled back with `fftshift`.
@adjoint function ifftshift(x)
  unshifted = ifftshift(x)
  return unshifted, Δ -> (fftshift(Δ),)
end


# to actually use rfft, one needs to ensure that everything
# that happens in the Fourier domain could've been done in
# the space domain with real numbers. This means enforcing
# conjugate symmetry along all transformed dimensions besides
# the first. Otherwise this is going to result in *very* weird
# behavior.
# Reverse rule for `rfft` of a real array over all dimensions.
# The cotangent `Δ` is pulled back with `brfft`, which needs the length of the first
# dimension of the original real input to undo the halving done by `rfft`.
@adjoint function rfft(xs::AbstractArray{<:Real})
  return AbstractFFTs.rfft(xs), function(Δ)
    originalSize = size(xs, 1)
    return (AbstractFFTs.brfft(Δ, originalSize),)
  end
end

# Reverse rule for `irfft(xs, d)`: the real part of the cotangent is pushed through
# `rfft` and divided by the number of elements of the cotangent. `d` gets no gradient.
@adjoint function irfft(xs, d)
  y = AbstractFFTs.irfft(xs, d)
  function irfft_pullback(Δ)
    x̄ = AbstractFFTs.rfft(real.(Δ)) / length(Δ)
    return (x̄, nothing)
  end
  return y, irfft_pullback
end

# Reverse rule for `brfft(xs, d)`: the real part of the cotangent is pushed through
# `rfft` without rescaling. `d` gets no gradient.
@adjoint function brfft(xs, d)
  y = AbstractFFTs.brfft(xs, d)
  brfft_pullback(Δ) = (AbstractFFTs.rfft(real.(Δ)), nothing)
  return y, brfft_pullback
end


# if we're specifying the dimensions
# Reverse rule for `fft(xs, dims)`: the cotangent is pulled back with `bfft` over the
# same dimensions. `dims` gets no gradient.
@adjoint function fft(xs, dims)
  return AbstractFFTs.fft(xs, dims), function(Δ)
    # dims can be int, array or tuple; convert to a collection.
    # Bound to a fresh local instead of reassigning the captured `dims`, because
    # reassigning a captured variable forces it to be boxed.
    region = collect(dims)
    return (AbstractFFTs.bfft(Δ, region), nothing)
  end
end

# Reverse rule for `bfft(xs, dims)`: the cotangent is pulled back with a forward `fft`
# over the same dimensions. `dims` gets no gradient.
@adjoint function bfft(xs, dims)
  # The primal must be the unnormalized `bfft`; the previous code returned `ifft`,
  # which differs from `bfft` by the normalization factor.
  return AbstractFFTs.bfft(xs, dims), function(Δ)
    # Fresh local instead of reassigning the captured `dims` (avoids boxing).
    region = collect(dims)
    return (AbstractFFTs.fft(Δ, region), nothing)
  end
end

# Reverse rule for `ifft(xs, dims)`: the cotangent is pulled back with a forward `fft`
# over the same dimensions, divided by the product of the transformed dimension sizes.
# `dims` gets no gradient.
@adjoint function ifft(xs, dims)
  return AbstractFFTs.ifft(xs, dims), function(Δ)
    # Fresh local instead of reassigning the captured `dims` (avoids boxing).
    region = collect(dims)
    N = prod(collect(size(xs))[region])
    return (AbstractFFTs.fft(Δ, region)/N, nothing)
  end
end

# Reverse rule for `rfft(xs, dims)`: the cotangent is pulled back with `irfft` over the
# same dimensions and multiplied by `N`, the product of the transformed dimension sizes
# of `xs`. `dims` gets no gradient.
@adjoint function rfft(xs, dims)
  return AbstractFFTs.rfft(xs, dims), function(Δ)
    # Fresh local instead of reassigning the captured `dims` (avoids boxing).
    region = collect(dims)
    N = prod(collect(size(xs))[region])
    # `irfft` needs the original length of the first transformed dimension.
    return (N * AbstractFFTs.irfft(Δ, size(xs, region[1]), region), nothing)
  end
end

# Reverse rule for `irfft(xs, d, dims)`: the real part of the cotangent is pushed
# through `rfft` over the same dimensions and divided by the normalization factor.
# `d` and `dims` get no gradient.
@adjoint function irfft(xs, d, dims)
  return AbstractFFTs.irfft(xs, d, dims), function(Δ)
    # Fresh local instead of reassigning the captured `dims` (avoids boxing).
    region = collect(dims)
    # `irfft` normalizes by the size of its real output along the transformed
    # dimensions, so take the factor from `Δ` (which has the output's shape) rather
    # than from the half-spectrum input `xs`. This matches the dims-free rule, which
    # divides by `length(Δ)`.
    N = prod(collect(size(Δ))[region])
    return (AbstractFFTs.rfft(real.(Δ), region)/N, nothing, nothing)
  end
end
# Reverse rule for `brfft(xs, d, dims)`: the real part of the cotangent is pushed
# through `rfft` over the same dimensions without rescaling. `d` and `dims` get no
# gradient.
@adjoint function brfft(xs, d, dims)
  return AbstractFFTs.brfft(xs, d, dims), function(Δ)
    # Fresh local instead of reassigning the captured `dims` (avoids boxing).
    region = collect(dims)
    return (AbstractFFTs.rfft(real.(Δ), region), nothing, nothing)
  end
end


# Reverse rule for `fftshift(x, dims)`: the cotangent is pulled back with `ifftshift`
# over the same dimensions. `dims` gets no gradient.
@adjoint function fftshift(x, dims)
  # Forward `dims` to the primal call; the previous code shifted every dimension
  # regardless of the `dims` the caller asked for.
  return fftshift(x, dims), function(Δ)
    return (ifftshift(Δ, dims), nothing)
  end
end

# Reverse rule for `ifftshift(x, dims)`: the cotangent is pulled back with `fftshift`
# over the same dimensions. `dims` gets no gradient.
@adjoint function ifftshift(x, dims)
  # Forward `dims` to the primal call; the previous code shifted every dimension
  # regardless of the `dims` the caller asked for.
  return ifftshift(x, dims), function(Δ)
    return (fftshift(Δ, dims), nothing)
  end
end

# FillArray functionality
# =======================

Expand Down
102 changes: 47 additions & 55 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1621,16 +1621,15 @@ end

@testset "AbstractFFTs" begin

# Many of these tests check a complex gradient to a function with real input. This is now
# clamped to real by ProjectTo, but to run the old tests, use here the old gradient function:
function oldgradient(f, args...)
y, back = Zygote.pullback(f, args...)
back(Zygote.sensitivity(y))
end
# Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests
# can be updated to use real / complex consistently.
# Eventually these rules and tests will be moved to AbstractFFTs.jl
# Rules for direct invocation of [i,r,b]fft have already been defined in
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58

# ChainRules involving AbstractFFTs.Plan are not yet part of AbstractFFTs,
# but there is a WIP PR:
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/67
# After the above is merged, this testset can probably be removed entirely.

findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1,
l=1:n2]
mirrorIndex(i,N) = i - 2*max(0,i - (N>>1+1))
Expand All @@ -1643,45 +1642,41 @@ end
indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1),
l=1:size(X,2)]
# gradient of ifft(fft) must be (approximately) 1 (for various cases)
@test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
@test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
# same for the inverse
@test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
@test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
# same for rfft(irfft)
@test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
# rfft isn't actually surjective, so rffft(irfft) can't really be tested this way.
@test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
# rfft isn't actually surjective, so rfft(irfft) can't really be tested this way.

# the gradients are actually just evaluating the inverse transform on the
# indicator matrix
mirrorI = mirrorIndex(i,sizeX[1])
FreqIndMat = findicateMat(mirrorI, j, size(X̂r,1), sizeX[2])
listOfSols = [(fft, bfft(indicateMat), bfft(indicateMat*im),
plan_fft(X), i, X),
(ifft, 1/N*fft(indicateMat), 1/N*fft(indicateMat*im),
plan_fft(X), i, X),
(bfft, fft(indicateMat), fft(indicateMat*im), nothing, i,
X),
(rfft, real.(brfft(FreqIndMat, sizeX[1])),
real.(brfft(FreqIndMat*im, sizeX[1])), plan_rfft(X),
mirrorI, X),
((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat),
zeros(size(X̂r)), plan_rfft(X), i, X̂r)]
for (trans, solRe, solIm, P, mI, evalX) in listOfSols
@test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
listOfSols = [(X -> fft(X, (1, 2)), real(bfft(indicateMat)), real(bfft(indicateMat*im)),
plan_fft(X), i, X, true),
(K -> ifft(K, (1, 2)), 1/N*real(fft(indicateMat)), 1/N*real(fft(indicateMat*im)),
plan_fft(X), i, X, false),
(X -> bfft(X, (1, 2)), real(fft(indicateMat)), real(fft(indicateMat*im)), nothing, i,
X, false),
]
for (trans, solRe, solIm, P, mI, evalX, fft_or_rfft) in listOfSols
@test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
solRe
@test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
@test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
solIm
if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft])
@test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
if typeof(P) <:AbstractFFTs.Plan && fft_or_rfft
@test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
solRe
@test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
@test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
solIm
elseif typeof(P) <: AbstractFFTs.Plan
@test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
@test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
solRe
# for whatever reason the rfft_plan doesn't handle this case well,
# even though irfft does
if eltype(evalX) <: Real
@test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
@test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
solIm
end
end
Expand All @@ -1692,47 +1687,44 @@ end
x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982]
# check ffts for individual dimensions
for trans in (fft, ifft, bfft)
@test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈
oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
@test gradient((x)->sum(abs.(trans(x, (1, 2)))), x)[1] ≈
gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
# switch sum abs order
@test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈
oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
@test gradient((x)->abs(sum((trans(x)))),x)[1] ≈
gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
# dims parameter for the function
@test oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
oldgradient( (x) -> sum(abs.(trans(x))), x)[1]
# (1,2) should be the same as no index
@test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
@test gradcheck(x->sum(abs.(trans(x))), x)
@test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
gradient( (x) -> sum(abs.(trans(x, (1, 2)))), x)[1]
@test gradcheck(x->sum(abs.(trans(x, (1, 2)))), x)
@test gradcheck(x->sum(abs.(trans(x, 2))), x)
end

@test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈
oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
@test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
oldgradient( (x) -> sum(abs.(rfft(x))), x)[1]
@test gradient((x)->sum(abs.(rfft(x, (1, 2)))), x)[1] ≈
gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
@test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
gradient( (x) -> sum(abs.(rfft(x, (1, 2)))), x)[1]

# Test type stability of fft

x = randn(Float64,16)
P = plan_fft(x)
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
@test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
@test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float64,1}
@test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float64,1}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float64,1}

x = randn(Float64,16,16)
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float64,2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}

x = randn(Float32,16)
P = plan_fft(x)
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
@test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
@test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float32,1}
@test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float32,1}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float32,1}

x = randn(Float32,16,16)
@test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
@test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float32,2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
end

@testset "FillArrays" begin
Expand Down