From 9805dd1d6433ba6109ad8a727cc843b0703a2824 Mon Sep 17 00:00:00 2001 From: Wolfhart Feldmeier Date: Mon, 6 Mar 2023 11:30:21 +0100 Subject: [PATCH 1/4] drop adjoints for [i,r,b]fft() Partially addresses https://github.com/FluxML/Zygote.jl/issues/1377 ChainRules for these have been added in https://github.com/JuliaMath/AbstractFFTs.jl/pull/58 --- src/lib/array.jl | 123 ---------------------------------------------- test/gradcheck.jl | 101 +++++++++++++++++-------------------- 2 files changed, 46 insertions(+), 178 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 5086d9c14..dcec8ee21 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -671,12 +671,6 @@ AbstractFFTs.brfft(x::Fill, d, dims...) = AbstractFFTs.brfft(collect(x), d, dims # the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the # gradient of its inputs, but with different normalization factor -@adjoint function fft(xs) - return AbstractFFTs.fft(xs), function(Δ) - return (AbstractFFTs.bfft(Δ),) - end -end - @adjoint function *(P::AbstractFFTs.Plan, xs) return P * xs, function(Δ) N = prod(size(xs)[[P.region...]]) @@ -691,123 +685,6 @@ end end end -# all of the plans normalize their inverse, while we need the unnormalized one. -@adjoint function ifft(xs) - return AbstractFFTs.ifft(xs), function(Δ) - N = length(xs) - return (AbstractFFTs.fft(Δ)/N,) - end -end - -@adjoint function bfft(xs) - return AbstractFFTs.bfft(xs), function(Δ) - return (AbstractFFTs.fft(Δ),) - end -end - -@adjoint function fftshift(x) - return fftshift(x), function(Δ) - return (ifftshift(Δ),) - end -end - -@adjoint function ifftshift(x) - return ifftshift(x), function(Δ) - return (fftshift(Δ),) - end -end - - -# to actually use rfft, one needs to insure that everything -# that happens in the Fourier domain could've been done in -# the space domain with real numbers. This means enforcing -# conjugate symmetry along all transformed dimensions besides -# the first. Otherwise this is going to result in *very* weird -# behavior. -@adjoint function rfft(xs::AbstractArray{<:Real}) - return AbstractFFTs.rfft(xs), function(Δ) - N = length(Δ) - originalSize = size(xs,1) - return (AbstractFFTs.brfft(Δ, originalSize),) - end -end - -@adjoint function irfft(xs, d) - return AbstractFFTs.irfft(xs, d), function(Δ) - total = length(Δ) - fullTransform = AbstractFFTs.rfft(real.(Δ))/total - return (fullTransform, nothing) - end -end - -@adjoint function brfft(xs, d) - return AbstractFFTs.brfft(xs, d), function(Δ) - fullTransform = AbstractFFTs.rfft(real.(Δ)) - return (fullTransform, nothing) - end -end - - -# if we're specifying the dimensions -@adjoint function fft(xs, dims) - return AbstractFFTs.fft(xs, dims), function(Δ) - # dims can be int, array or tuple, - # convert to collection for use as index - dims = collect(dims) - return (AbstractFFTs.bfft(Δ, dims), nothing) - end -end - -@adjoint function bfft(xs, dims) - return AbstractFFTs.ifft(xs, dims), function(Δ) - dims = collect(dims) - return (AbstractFFTs.fft(Δ, dims),nothing) - end -end - -@adjoint function ifft(xs, dims) - return AbstractFFTs.ifft(xs, dims), function(Δ) - dims = collect(dims) - N = prod(collect(size(xs))[dims]) - return (AbstractFFTs.fft(Δ, dims)/N,nothing) - end -end - -@adjoint function rfft(xs, dims) - return AbstractFFTs.rfft(xs, dims), function(Δ) - dims = collect(dims) - N = prod(collect(size(xs))[dims]) - return (N * AbstractFFTs.irfft(Δ, size(xs,dims[1]), dims), nothing) - end -end - -@adjoint function irfft(xs, d, dims) - return AbstractFFTs.irfft(xs, d, dims), function(Δ) - dims = collect(dims) - N = prod(collect(size(xs))[dims]) - return (AbstractFFTs.rfft(real.(Δ), dims)/N, nothing, nothing) - end -end -@adjoint function brfft(xs, d, dims) - return AbstractFFTs.brfft(xs, d, dims), function(Δ) - dims = collect(dims) - return (AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing) - end -end - - -@adjoint function fftshift(x, dims) - return fftshift(x), function(Δ) - return (ifftshift(Δ, dims), nothing) - end -end - -@adjoint function ifftshift(x, dims) - return ifftshift(x), function(Δ) - return (fftshift(Δ, dims), nothing) - end -end - # FillArray functionality # ======================= diff --git a/test/gradcheck.jl b/test/gradcheck.jl index b170aa045..14903bb6d 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1621,15 +1621,13 @@ end @testset "AbstractFFTs" begin - # Many of these tests check a complex gradient to a function with real input. This is now - # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function: - function oldgradient(f, args...) - y, back = Zygote.pullback(f, args...) - back(Zygote.sensitivity(y)) - end - # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests - # can be updated to use real / complex consistently. + # Eventually these rules and tests will be moved to AbstractFFTs.jl + # Rules for direct invocation of [i,r,b]fft have already been defined in # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58 + # however these apply only if the full method signature (including `region`) is used. + # Otherwise, rules for multiplication and left division by a AbstractFFTs.Plan + # are required, which will be implemented in + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/67 findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1, l=1:n2] @@ -1643,45 +1641,41 @@ end indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1), l=1:size(X,2)] # gradient of ifft(fft) must be (approximately) 1 (for various cases) - @test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat + @test gradient((X)->real.(ifft(fft(X, (1, 2)), (1, 2))[i, j]), X)[1] ≈ indicateMat # same for the inverse - @test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat + @test gradient((X̂)->real.(fft(ifft(X̂, (1, 2)), (1, 2))[i, j]), X̂)[1] ≈ indicateMat # same for rfft(irfft) - @test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat) - # rfft isn't actually surjective, so rffft(irfft) can't really be tested this way. + @test gradient((X)->real.(irfft(rfft(X, (1, 2)), size(X,1), (1, 2)))[i, j], X)[1] ≈ real.(indicateMat) + # rfft isn't actually surjective, so rfft(irfft) can't really be tested this way. # the gradients are actually just evaluating the inverse transform on the # indicator matrix mirrorI = mirrorIndex(i,sizeX[1]) FreqIndMat = findicateMat(mirrorI, j, size(X̂r,1), sizeX[2]) - listOfSols = [(fft, bfft(indicateMat), bfft(indicateMat*im), - plan_fft(X), i, X), - (ifft, 1/N*fft(indicateMat), 1/N*fft(indicateMat*im), - plan_fft(X), i, X), - (bfft, fft(indicateMat), fft(indicateMat*im), nothing, i, - X), - (rfft, real.(brfft(FreqIndMat, sizeX[1])), - real.(brfft(FreqIndMat*im, sizeX[1])), plan_rfft(X), - mirrorI, X), - ((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat), - zeros(size(X̂r)), plan_rfft(X), i, X̂r)] - for (trans, solRe, solIm, P, mI, evalX) in listOfSols - @test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈ + listOfSols = [(X -> fft(X, (1, 2)), real(bfft(indicateMat)), real(bfft(indicateMat*im)), + plan_fft(X), i, X, true), + (K -> ifft(K, (1, 2)), 1/N*real(fft(indicateMat)), 1/N*real(fft(indicateMat*im)), + plan_fft(X), i, X, false), + (X -> bfft(X, (1, 2)), real(fft(indicateMat)), real(fft(indicateMat*im)), nothing, i, + X, false), + ] + for (trans, solRe, solIm, P, mI, evalX, fft_or_rfft) in listOfSols + @test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈ solRe - @test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈ + @test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈ solIm - if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft]) - @test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈ + if typeof(P) <:AbstractFFTs.Plan && fft_or_rfft + @test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈ solRe - @test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈ + @test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈ solIm elseif typeof(P) <: AbstractFFTs.Plan - @test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈ + @test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈ solRe # for whatever reason the rfft_plan doesn't handle this case well, # even though irfft does if eltype(evalX) <: Real - @test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈ + @test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈ solIm end end @@ -1692,47 +1686,44 @@ end x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982] # check ffts for individual dimensions for trans in (fft, ifft, bfft) - @test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈ - oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] + @test gradient((x)->sum(abs.(trans(x, (1, 2)))), x)[1] ≈ + gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] # switch sum abs order - @test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈ - oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1] + @test gradient((x)->abs(sum((trans(x)))),x)[1] ≈ + gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1] # dims parameter for the function - @test oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈ - oldgradient( (x) -> sum(abs.(trans(x))), x)[1] - # (1,2) should be the same as no index - @test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈ - oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1] - @test gradcheck(x->sum(abs.(trans(x))), x) + @test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈ + gradient( (x) -> sum(abs.(trans(x, (1, 2)))), x)[1] + @test gradcheck(x->sum(abs.(trans(x, (1, 2)))), x) @test gradcheck(x->sum(abs.(trans(x, 2))), x) end - @test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈ - oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1] - @test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈ - oldgradient( (x) -> sum(abs.(rfft(x))), x)[1] + @test gradient((x)->sum(abs.(rfft(x, (1, 2)))), x)[1] ≈ + gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1] + @test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈ + gradient( (x) -> sum(abs.(rfft(x, (1, 2)))), x)[1] # Test type stability of fft x = randn(Float64,16) P = plan_fft(x) - @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1} - @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1} - @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1} + @test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float64,1} + @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float64,1} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float64,1} x = randn(Float64,16,16) - @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2} - @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} + @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float64,2} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} x = randn(Float32,16) P = plan_fft(x) - @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1} - @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1} - @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1} + @test typeof(gradient(x->sum(abs2,ifft(fft(x, 1), 1)),x)[1]) == Array{Float32,1} + @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Float32,1} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x, 1),16, 1)),x)[1]) == Array{Float32,1} x = randn(Float32,16,16) - @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2} - @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} + @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Float32,2} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} end @testset "FillArrays" begin From a5342294e282ae1b296d2da1b0c3d5efe16d6388 Mon Sep 17 00:00:00 2001 From: Wolfhart Feldmeier Date: Wed, 8 Mar 2023 15:33:13 +0100 Subject: [PATCH 2/4] add back gradient test for *fft without dims argument --- test/gradcheck.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 14903bb6d..6a3a1f6c3 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1624,10 +1624,11 @@ end # Eventually these rules and tests will be moved to AbstractFFTs.jl # Rules for direct invocation of [i,r,b]fft have already been defined in # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58 - # however these apply only if the full method signature (including `region`) is used. - # Otherwise, rules for multiplication and left division by a AbstractFFTs.Plan - # are required, which will be implemented in + + # ChainRules involing AbstractFFTs.Plan are not yet part of AbstractFFTs, + # but there is a WIP PR: # https://github.com/JuliaMath/AbstractFFTs.jl/pull/67 + # After the above is merged, this testset can probably be removed entirely. findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1, l=1:n2] @@ -1641,11 +1642,11 @@ end indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1), l=1:size(X,2)] # gradient of ifft(fft) must be (approximately) 1 (for various cases) - @test gradient((X)->real.(ifft(fft(X, (1, 2)), (1, 2))[i, j]), X)[1] ≈ indicateMat + @test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat # same for the inverse - @test gradient((X̂)->real.(fft(ifft(X̂, (1, 2)), (1, 2))[i, j]), X̂)[1] ≈ indicateMat + @test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat # same for rfft(irfft) - @test gradient((X)->real.(irfft(rfft(X, (1, 2)), size(X,1), (1, 2)))[i, j], X)[1] ≈ real.(indicateMat) + @test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat) # rfft isn't actually surjective, so rfft(irfft) can't really be tested this way. # the gradients are actually just evaluating the inverse transform on the From 8c531d8d4ce204200404a59c42011a62cbabb87d Mon Sep 17 00:00:00 2001 From: Wolfhart Feldmeier Date: Wed, 8 Mar 2023 15:34:11 +0100 Subject: [PATCH 3/4] increase compat constraint for AbstractFFTs to 1.3.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 93b548336..f78a49654 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] -AbstractFFTs = "0.5, 1.0" +AbstractFFTs = "1.3.1" ChainRules = "1.44.1" ChainRulesCore = "1.9" ChainRulesTestUtils = "1" From cc5b12a17e7fbded8abd212b7af4976af2be28a4 Mon Sep 17 00:00:00 2001 From: trahflow Date: Wed, 8 Mar 2023 22:13:11 +0100 Subject: [PATCH 4/4] fix typo Co-authored-by: Brian Chen --- test/gradcheck.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 6a3a1f6c3..6e9b954fc 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1625,7 +1625,7 @@ end # Rules for direct invocation of [i,r,b]fft have already been defined in # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58 - # ChainRules involing AbstractFFTs.Plan are not yet part of AbstractFFTs, + # ChainRules involving AbstractFFTs.Plan are not yet part of AbstractFFTs, # but there is a WIP PR: # https://github.com/JuliaMath/AbstractFFTs.jl/pull/67 # After the above is merged, this testset can probably be removed entirely.