Skip to content

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 22, 2022

This should fix the bug from https://discourse.julialang.org/t/forwarddiff-jl-error-loaderror-methoderror-convert-is-ambiguous/86132 , which must be cause by #655

julia> using Zygote

julia> function f3(x)
           A = ones(5,5)*x
           maximum(A)
       end
f3 (generic function with 1 method)

julia> gradient(f3, 0.5)
(1.0,)

julia> hessian(f3, 0.5)
ERROR: MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, ::ChainRulesCore.ZeroTangent) is ambiguous.

Candidates:
  convert(::Type{T}, x::ChainRulesCore.AbstractZero) where T<:Number
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/abstract_zero.jl:31
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N}
    @ ForwardDiff ~/.julia/packages/ForwardDiff/pDtsf/src/dual.jl:432

Possible fix, define
  convert(::Type{ForwardDiff.Dual{T, V, N}}, ::ChainRulesCore.AbstractZero) where {T, V, N}

Stacktrace:
  [1] fill!(dest::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, x::ChainRulesCore.ZeroTangent)
    @ Base ./array.jl:347
  [2] _setindex_zero(::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, ::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}, ::Int64, ::Vararg{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/indexing.jl:104
  [3] ∇getindex(x::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, dy::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}, inds::CartesianIndex{2})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/indexing.jl:89

Xref JuliaDiff/ChainRulesCore.jl#448


I think the reason Zygote is calling this rule at all is that its rrule_via_ad function takes a shortcut which doesn't check its own rules:

https://github.com/FluxML/Zygote.jl/blob/99d5a38b14dc842643acfa624b6f0f89061efbbf/src/compiler/chainrules.jl#L243-L246

Edit: maybe not, sorry. The rule for maximum calls ∇getindex directly.


Needs a test. Do we add ForwardDiff just for this?

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Aug 22, 2022
@oxinabox
Copy link
Member

Needs a test. Do we add ForwardDiff just for this?

I would rather not.
I am not sure how specific we want to be just now with what we promise for this function.
So I think it is fine to just leave this as is -- it passes current tests whicch i assume hit the changed line.

So I am kind of ok having this without additional tests

@mcabbott mcabbott merged commit 39c2d17 into main Aug 23, 2022
@mcabbott mcabbott deleted the mcabbott-patch-3 branch August 23, 2022 11:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

needs version bump Version needs to be incremented or set to -DEV in Project.toml

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants