
Incorrect gradients for plan_rfft(x) * x #1496

@awadell1

Description


The gradients Zygote returns for a planned rfft (plan_rfft(x) * x) differ significantly from the ChainRulesCore gradients. This appears to be related to #1437, #899, and #1377.

Adjoint in question:

From Zygote.jl/src/lib/array.jl, lines 645 to 659 at commit 54f1e80:

# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with different normalization factor
@adjoint function *(P::AbstractFFTs.Plan, xs)
  return P * xs, function(Δ)
    N = prod(size(xs)[[P.region...]])
    return (nothing, N * (P \ Δ))
  end
end
@adjoint function \(P::AbstractFFTs.Plan, xs)
  return P \ xs, function(Δ)
    N = prod(size(Δ)[[P.region...]])
    return (nothing, (P * Δ)/N)
  end
end
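The `Plan * xs` rule relies on `N * (P \ Δ)` being the adjoint of the forward transform. For a full complex FFT plan, `P \ Δ` is `ifft(Δ)`, so `N * (P \ Δ)` equals `bfft(Δ)`, which is the adjoint of the unnormalized forward DFT. For an `rfft` plan, `P \ Δ` is `irfft`, which treats `Δ` as half of a Hermitian-symmetric spectrum and adds back the conjugate-mirrored frequencies that `rfft` dropped, so the scaled inverse is no longer the adjoint. A minimal dot-product (adjoint) test sketch, assuming FFTW.jl and the LinearAlgebra stdlib, on random inputs; the variable names are illustrative only:

```julia
using FFTW, LinearAlgebra

x = rand(3, 3)
N = length(x)  # equals prod(size(x)[[P.region...]]) here, since both plans cover every dimension

# Complex FFT plan: P \ Δ is ifft(Δ), so N * (P \ Δ) == bfft(Δ), the true adjoint.
xc = complex.(x)
P = plan_fft(xc)
Δ = rand(ComplexF64, 3, 3)
adjoint_ok = dot(Δ, P * xc) ≈ dot(N * (P \ Δ), xc)

# Real FFT plan: P \ Δ is irfft, which adds back the conjugate-mirrored
# frequencies that rfft dropped, so N * (P \ Δ) is not the adjoint of rfft.
Pr = plan_rfft(x)
Δr = rand(ComplexF64, 2, 3)  # rfft of a 3×3 array has size 2×3
adjoint_rfft = real(dot(Δr, Pr * x)) ≈ dot(N * (Pr \ Δr), x)

println((adjoint_ok, adjoint_rfft))  # expect (true, false) for generic random inputs
```

The identity holds for the complex plan and fails for the real plan, which is consistent with the mismatch reported below.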

Minimal Working Example

using Pkg
Pkg.activate(; temp=true)
Pkg.add(["Zygote", "FFTW", "ChainRulesCore"])
using Zygote
using FFTW
using ChainRulesCore

x = rand(3,3)

# No Dims
p = plan_rfft(x)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - No Dims" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - No Dims" back(one.(y))[3]

# dims = 1
p = plan_rfft(x, 1)
y, back = Zygote.pullback(*, p, x)
@info "Zygote - dims=1" back(one.(y))[2]
y, back = ChainRulesCore.rrule(*, p, x)
@info "ChainRules - dims=1" back(one.(y))[3]

Output:

❯ julia zygote.jl 
  Activating new project at `/tmp/jl_9uzHHd`
   Resolving package versions...
    Updating `/tmp/jl_9uzHHd/Project.toml`
  [d360d2e6] + ChainRulesCore v1.19.1
  [7a1cc6ca] + FFTW v1.7.2
  [e88e6eb3] + Zygote v0.6.68
    Updating `/tmp/jl_9uzHHd/Manifest.toml`
  [621f4979] + AbstractFFTs v1.5.0
[...]
┌ Info: Zygote - No Dims
│   (back(one.(y)))[2] =
│    3×3 Matrix{Float64}:
│     9.0  0.0  0.0
│     0.0  0.0  0.0
│     0.0  0.0  0.0
┌ Info: ChainRules - No Dims
│   (back(one.(y)))[3] =
│    3×3 Matrix{Float64}:
│     6.0  0.0  0.0
│     1.5  0.0  0.0
│     1.5  0.0  0.0
┌ Info: Zygote - dims=1
│   (back(one.(y)))[2] =
│    3×3 Matrix{Float64}:
│     3.0  3.0  3.0
│     0.0  0.0  0.0
│     0.0  0.0  0.0
┌ Info: ChainRules - dims=1
│   (back(one.(y)))[3] =
│    3×3 Matrix{Float64}:
│     2.0  2.0  2.0
│     0.5  0.5  0.5
│     0.5  0.5  0.5
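To check which of the outputs above is correct without relying on either AD rule, the rfft pullback can be written out directly: zero-pad the cotangent back to full length along the halved dimension, apply the unnormalized backward DFT (`bfft`), and keep the real part. The sketch below assumes FFTW.jl, a real matrix input, and that dimension 1 is the one rfft halves (true for both `plan_rfft(x)` and `plan_rfft(x, 1)`). `rfft_pullback_ref` and `fd_grad` are hypothetical helper names, not part of Zygote, ChainRules, or AbstractFFTs.

```julia
using FFTW  # rfft and bfft come from AbstractFFTs and are re-exported by FFTW

# Hypothetical helper: reference pullback of y = rfft(x, dims) for a real
# matrix x with n rows, assuming dims[1] == 1 (the dimension rfft halves).
function rfft_pullback_ref(dy::AbstractMatrix, n::Integer, dims)
    # rfft drops the upper n - size(dy, 1) frequencies along dim 1,
    # so their cotangent is zero.
    pad = zeros(eltype(dy), n - size(dy, 1), size(dy, 2))
    # The adjoint of the unnormalized forward DFT is the unnormalized
    # backward DFT; x is real, so only the real part is kept.
    return real.(bfft(vcat(dy, pad), dims))
end

# Hypothetical helper: central finite-difference gradient of a scalar f.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = rand(3, 3)
dy = ones(ComplexF64, 2, 3)  # same cotangent as one.(y) in the example above

# With an all-ones cotangent the scalar loss is L(x) = sum(real.(rfft(x, dims))).
g_all   = rfft_pullback_ref(dy, size(x, 1), (1, 2))  # ≈ [6 0 0; 1.5 0 0; 1.5 0 0]
g_dims1 = rfft_pullback_ref(dy, size(x, 1), 1)       # ≈ [2 2 2; 0.5 0.5 0.5; 0.5 0.5 0.5]

println(isapprox(g_all, fd_grad(x -> sum(real.(rfft(x))), x); atol = 1e-6))     # true
println(isapprox(g_dims1, fd_grad(x -> sum(real.(rfft(x, 1))), x); atol = 1e-6)) # true
```

Both reference gradients match the ChainRules output and the finite-difference estimates, while the Zygote results (9, 0, 0, ... and 3, 3, 3, 0, ...) match N * irfft of the cotangent instead.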
