Closed
Description
Gradients that Zygote computes for a planned `rfft` (`p * x` with `p = plan_rfft(x)`) differ significantly from those returned by `ChainRulesCore.rrule` for the same call. This appears to be related to #1437, #899, and #1377.
Adjoint in question (lines 645 to 659 in 54f1e80):
```julia
# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with different normalization factor
@adjoint function *(P::AbstractFFTs.Plan, xs)
  return P * xs, function(Δ)
    N = prod(size(xs)[[P.region...]])
    return (nothing, N * (P \ Δ))
  end
end
@adjoint function \(P::AbstractFFTs.Plan, xs)
  return P \ xs, function(Δ)
    N = prod(size(Δ)[[P.region...]])
    return (nothing, (P * Δ)/N)
  end
end
```
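For a real-to-complex plan, `P \ Δ` is an inverse real FFT, which assumes the full spectrum is Hermitian-symmetric and therefore adds back the negative-frequency bins that `rfft` dropped. A 1-D sketch of the algebra, using the real inner product $\langle a, b\rangle = \operatorname{Re}\sum_k a_k^{*} b_k$ for complex cotangents (the symbols $n$, $m$, $\omega$, $c_k$ are introduced here for illustration and are not taken from either package), shows where the rule over-counts:

```math
\begin{aligned}
y_k &= \sum_{j=0}^{n-1} x_j\,\omega^{jk}, \qquad k = 0,\dots,m-1, \quad m = \lfloor n/2 \rfloor + 1, \quad \omega = e^{-2\pi i/n},\\
\bar{x}_j &= \operatorname{Re}\sum_{k=0}^{m-1} \Delta_k\,\omega^{-jk} \qquad \text{(correct pullback)},\\
n\,\bigl[P \backslash \Delta\bigr]_j &= \operatorname{Re}\sum_{k=0}^{m-1} c_k\,\Delta_k\,\omega^{-jk}, \qquad c_k = \begin{cases} 2, & 1 \le k \le n - m,\\ 1, & \text{otherwise.} \end{cases}
\end{aligned}
```

Bins with $1 \le k \le n-m$ (those that have a mirrored negative-frequency partner) are counted twice, so dividing them by 2 before inverting, $\bar{x} = n\,\operatorname{irfft}(\Delta \oslash c,\, n)$, is one way to recover the correct pullback. In the `dims=1` example, $n = 3$ gives $m = 2$ and $c = (1, 2)$; with $\Delta = (1, 1)$ the correct pullback is $(2, 0.5, 0.5)$ while the rule returns $(3, 0, 0)$. In the no-dims case, the additional complex DFT along the second dimension only scales the first column by 3, giving $(6, 1.5, 1.5)$ versus $(9, 0, 0)$. For multi-dimensional plans, the halved dimension is the first one in `P.region`.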
Minimal Working Example
```julia
using Pkg
Pkg.activate(; temp=true)
Pkg.add(["Zygote", "FFTW", "ChainRulesCore"])
using Zygote
using FFTW
using ChainRulesCore

x = rand(3,3)

# No Dims
p = plan_rfft(x)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - No Dims" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - No Dims" back(one.(y))[3]

# dims = 1
p = plan_rfft(x, 1)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - dims=1" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - dims=1" back(one.(y))[3]
```
Output:
```
❯ julia zygote.jl
Activating new project at `/tmp/jl_9uzHHd`
Resolving package versions...
Updating `/tmp/jl_9uzHHd/Project.toml`
[d360d2e6] + ChainRulesCore v1.19.1
[7a1cc6ca] + FFTW v1.7.2
[e88e6eb3] + Zygote v0.6.68
Updating `/tmp/jl_9uzHHd/Manifest.toml`
[621f4979] + AbstractFFTs v1.5.0
[...]
┌ Info: Zygote - No Dims
│ (back(one.(y)))[2] =
│ 3×3 Matrix{Float64}:
│ 9.0 0.0 0.0
│ 0.0 0.0 0.0
└ 0.0 0.0 0.0
┌ Info: ChainRules - No Dims
│ (back(one.(y)))[3] =
│ 3×3 Matrix{Float64}:
│ 6.0 0.0 0.0
│ 1.5 0.0 0.0
└ 1.5 0.0 0.0
┌ Info: Zygote - dims=1
│ (back(one.(y)))[2] =
│ 3×3 Matrix{Float64}:
│ 3.0 3.0 3.0
│ 0.0 0.0 0.0
└ 0.0 0.0 0.0
┌ Info: ChainRules - dims=1
│ (back(one.(y)))[3] =
│ 3×3 Matrix{Float64}:
│ 2.0 2.0 2.0
│ 0.5 0.5 0.5
└ 0.5 0.5 0.5
```
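To check the `dims=1` numbers without FFTW, here is a stdlib-only sketch for a single length-3 column. A naive O(n²) DFT stands in for FFTW, and all helper names (`naive_rfft`, `naive_rfft_pullback`, `naive_irfft`, `zygote_style_pullback`, `weighted_pullback`) are hypothetical. `naive_irfft` only mimics what `P \ Δ` computes for an rfft plan, and `weighted_pullback` is one assumed candidate fix, not code from Zygote or AbstractFFTs.

```julia
# Stdlib-only sketch: naive O(n^2) transforms replace FFTW so the adjoint can be
# checked by hand. All helper names here are hypothetical.

# Real-input DFT keeping bins k = 0, …, n÷2 (same output length as `rfft`).
function naive_rfft(x::AbstractVector{<:Real})
    n = length(x)
    return [sum(x[j+1] * cis(-2π * j * k / n) for j in 0:n-1) for k in 0:n÷2]
end

# Correct pullback of `naive_rfft`: x̄_j = Re Σ_k ȳ_k exp(+2πijk/n), k over kept bins only.
function naive_rfft_pullback(ȳ::AbstractVector{<:Complex}, n::Integer)
    return [real(sum(ȳ[k+1] * cis(2π * j * k / n) for k in 0:length(ȳ)-1)) for j in 0:n-1]
end

# Inverse real DFT (stand-in for `P \ Δ` on an rfft plan): rebuild the
# Hermitian-symmetric full spectrum, apply the inverse DFT, divide by n.
function naive_irfft(ȳ::AbstractVector{<:Complex}, n::Integer)
    m = length(ȳ)
    full = [k < m ? ȳ[k+1] : conj(ȳ[n-k+1]) for k in 0:n-1]
    return [real(sum(full[k+1] * cis(2π * j * k / n) for k in 0:n-1)) / n for j in 0:n-1]
end

# Mirrors the Zygote rule quoted above: N * (P \ Δ).
zygote_style_pullback(ȳ, n) = n .* naive_irfft(ȳ, n)

# Assumed fix: halve the bins that the inverse real DFT counts twice (1 ≤ k ≤ n - m).
function weighted_pullback(ȳ, n)
    m = length(ȳ)
    c = [1 ≤ k ≤ n - m ? 2.0 : 1.0 for k in 0:m-1]
    return n .* naive_irfft(ȳ ./ c, n)
end

ȳ = ones(ComplexF64, 2)                # cotangent `one.(y)` for one length-3 column
println(naive_rfft_pullback(ȳ, 3))     # ≈ [2.0, 0.5, 0.5] up to rounding (ChainRules column)
println(zygote_style_pullback(ȳ, 3))   # ≈ [3.0, 0.0, 0.0] up to rounding (Zygote column)
println(weighted_pullback(ȳ, 3))       # ≈ [2.0, 0.5, 0.5] up to rounding
```

A dot-product test, checking that $\operatorname{Re}\langle \Delta, \mathrm{rfft}(x)\rangle = \langle \bar{x}, x\rangle$ for arbitrary `x` and `Δ`, is a convenient way to validate any candidate rule; the same check can be run against the real `plan_rfft` with FFTW loaded.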