Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesOverloadGeneration"
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
version = "0.1.2"
version = "0.1.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 2 additions & 1 deletion src/ruleset_loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ _is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Var
_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}}

"check if this rule requires a particular configuation (`RuleConfig`)"
_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig
_requires_config(m::Method) = m.sig <: Tuple{Any, RuleConfig, Vararg}


const LAST_REFRESH_RRULE = Ref(0)
const LAST_REFRESH_FRULE = Ref(0)
Expand Down
27 changes: 20 additions & 7 deletions test/ruleset_loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,27 @@
end

@testset "should not have rrules that need RuleConfig" begin
old_rrule_list = collect(_rule_list(rrule))
function ChainRulesCore.rrule(
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
)
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
@testset "normal type sigs" begin
old_rrule_list = collect(_rule_list(rrule))
function ChainRulesCore.rrule(
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
)
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
end
# New rule should not have appeared
@test collect(_rule_list(rrule)) == old_rrule_list
end
@testset "UnionAll type sigs" begin
old_rrule_list = collect(_rule_list(rrule))
function ChainRulesCore.rrule(
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f::F, xs
) where F <: Function
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
end
# New rule should not have appeared
@test collect(_rule_list(rrule)) == old_rrule_list
# Above would error if we were not handling UnionAll's right
Copy link
Member

@mzgubic mzgubic Jun 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What part is the UnionAll?

I can't construct a method which has a signature of UnionAll actually

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a UnionAll over all values F can take

julia> function alt_rrule(
           ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f::F, xs
       ) where F <: Function
           return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
       end
alt_rrule (generic function with 1 method)

julia> method = first(methods(alt_rrule))
alt_rrule(::RuleConfig{var"#s3"} where var"#s3">:Union{HasForwardsMode, HasReverseMode}, sum, f::F, xs) where F<:Function in Main at REPL[9]:1

julia> method.sig
Tuple{typeof(alt_rrule), RuleConfig{var"#s3"} where var"#s3">:Union{HasForwardsMode, HasReverseMode}, Any, F, Any} where F<:Function

julia> typeof(method.sig)
UnionAll

end
# New rule should not have appeared
@test collect(_rule_list(rrule)) == old_rrule_list
end
end
end