Skip to content

Commit 99d5a38

Browse files
authored
Merge pull request #1278 from FluxML/bc/fix-rrule-lookup
Propagate ambiguities from rrule lookup instead of failing inexplicably
2 parents 5c80f55 + 328eb4d commit 99d5a38

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/compiler/chainrules.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
7373
matching_cr_sig(::DataType, ::UnionAll) = false
7474
matching_cr_sig(::UnionAll, ::DataType) = false
7575
matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)
76-
76+
matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234
77+
7778
type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
7879
function type_tuple_tail(d::UnionAll)
7980
body = Base.unwrap_unionall(d)

test/chainrules.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,15 @@ using Zygote: ZygoteRuleConfig
275275
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
276276
end
277277
end
278+
279+
# https://github.com/FluxML/Zygote.jl/issues/1234
280+
@testset "rrule lookup ambiguities" begin
281+
f_ambig(x, y) = x + y
282+
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
283+
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)
284+
285+
@test_throws MethodError pullback(f_ambig, 1, 2)
286+
end
278287
end
279288

280289
@testset "ChainRulesCore.rrule_via_ad" begin

0 commit comments

Comments
 (0)