Status: Open
Labels: needs adjoint, missing rule

Description
Not sure if this belongs here or in ChainRules, but taking the gradient of a function that calls `isposdef` fails with `ERROR: MethodError: no method matching iterate(::Nothing)`. See this MWE:
```julia
using Zygote, LinearAlgebra

# a is actually positive definite
a = rand(3, 3); a = a * a'

# works fine
function f(x)
    return x
end
gradient(f, 1.0)

# errors
function f(x)
    isposdef(a)
    return x
end
gradient(f, 1.0)
```
```
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at /usr/share/julia/base/range.jl:826
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at /usr/share/julia/base/range.jl:826
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at /usr/share/julia/base/dict.jl:695
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:92
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/chainrules.jl:229 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context, ::LinearAlgebra.var"#cholesky##kw", ::NamedTuple{(:check,), Tuple{Bool}}, ::typeof(cholesky), ::Hermitian{Float64, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:9
  [5] _pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/dense.jl:92 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::typeof(isposdef), args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [7] _pullback
    @ ~/Downloads/test_zygote.jl:151 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::typeof(f), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [9] _pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
 [10] pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
 [11] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [12] top-level scope
    @ ~/Downloads/test_zygote.jl:155
```
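Until a rule is added upstream, one possible user-side stopgap (a sketch, not a fix in Zygote or ChainRules) is to declare `isposdef` non-differentiable. It returns a `Bool`, so it has no meaningful derivative, and giving Zygote an explicit `rrule` for it means Zygote should never trace into the internal `cholesky(Hermitian(A); check = false)` call shown in the stack trace. This assumes `ChainRulesCore` is available in the environment; defining a rule for a function you do not own is type piracy, so treat it as temporary. The test matrix `a` below is a hypothetical fixed stand-in for the random one in the MWE.

```julia
using Zygote, LinearAlgebra, ChainRulesCore

# Workaround sketch (user-side, not an official fix): declare `isposdef`
# non-differentiable. The macro defines an `rrule` whose pullback returns
# `NoTangent()` for every argument, so Zygote uses it instead of tracing
# into `cholesky(Hermitian(A); check = false)`. Run this before the first
# `gradient` call that reaches `isposdef`.
ChainRulesCore.@non_differentiable isposdef(::Any)

# Hypothetical fixed test matrix: symmetric with eigenvalues 1 and 3,
# hence positive definite.
a = [2.0 1.0; 1.0 2.0]

function f(x)
    # The Bool result only selects a branch; no derivative flows through it.
    return isposdef(a) ? x : zero(x)
end

gradient(f, 1.0)
```

With the rule in place, `gradient(f, 1.0)` is expected to return `(1.0,)`, the same as for the identity function in the first part of the MWE.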
The problem seems to come from the `cholesky(Hermitian(...); check = false)` call that `isposdef` makes internally, as this fails too:
```julia
function f(x)
    isposdef(cholesky(Hermitian(a); check = false))
    return x
end
gradient(f, 1.0)
```

But interestingly enough, this works just fine:
```julia
function f(x)
    isposdef(cholesky(Symmetric(a); check = false))
    return x
end
gradient(f, 1.0)
```
```julia
function f(x)
    Hermitian(a)
    return x
end
gradient(f, 1.0)
```

Version info:
```
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 1
```
Package versions:

```
  [082447d4] ChainRules v1.35.1
  [d360d2e6] ChainRulesCore v1.15.0
  [a93c6f00] DataFrames v1.3.4
  [31c24e10] Distributions v0.25.62
  [6a86dc24] FiniteDiff v2.12.1
  [f6369f11] ForwardDiff v0.10.30
  [d3d80556] LineSearches v7.1.1
  [d41bc354] NLSolversBase v7.8.2
  [76087f3c] NLopt v0.6.5
  [429524aa] Optim v1.7.0
  [08abe8d2] PrettyTables v1.3.1
  [2913bbd2] StatsBase v0.33.16
  [78862bba] StenoGraphs v0.2.0
  [0c5d862f] Symbolics v4.6.0
  [e88e6eb3] Zygote v0.6.40
  [8bb1440f] DelimitedFiles
  [4af54fe1] LazyArtifacts
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg
  [9a3f8284] Random
  [2f01184e] SparseArrays
  [10745b16] Statistics
```
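If only real matrices are involved, another possible sketch is to route the check through `Symmetric`, since the report found that `isposdef(cholesky(Symmetric(a); check = false))` differentiates without error. The helper name `isposdef_sym` is hypothetical (not part of LinearAlgebra); it adds the symmetry check so that, for real input, it should give the same answer as `isposdef`. Whether `issymmetric` itself differentiates cleanly under Zygote v0.6.40 was not verified here.

```julia
using LinearAlgebra

# Hypothetical helper, not part of LinearAlgebra: the same test that
# `isposdef` performs for a real matrix, but routed through `Symmetric`
# instead of `Hermitian`. For real element types the two wrappers describe
# the same matrix, so the result should agree with `isposdef`.
function isposdef_sym(A::AbstractMatrix{<:Real})
    issymmetric(A) || return false               # `isposdef` also rejects non-symmetric input
    C = cholesky(Symmetric(A); check = false)    # no exception if the factorization fails
    return isposdef(C)                           # true iff the factorization succeeded
end

b = [4.0 1.0; 1.0 3.0]              # symmetric positive definite
isposdef_sym(b), isposdef(b)        # both true
isposdef_sym(-b), isposdef(-b)      # both false
```

If `issymmetric` turns out to be a problem under Zygote, it can likewise be declared non-differentiable with `ChainRulesCore.@non_differentiable issymmetric(::Any)`.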