diff --git a/Project.toml b/Project.toml index 8e10b5e..edde3ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesOverloadGeneration" uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" -version = "0.1.2" +version = "0.1.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index d59c306..b57cab3 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -66,7 +66,8 @@ _is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Var _is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}} "check if this rule requires a particular configuation (`RuleConfig`)" -_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig +_requires_config(m::Method) = m.sig <: Tuple{Any, RuleConfig, Vararg} + const LAST_REFRESH_RRULE = Ref(0) const LAST_REFRESH_FRULE = Ref(0) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 662a4e8..4e261a3 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -90,14 +90,27 @@ end @testset "should not have rrules that need RuleConfig" begin - old_rrule_list = collect(_rule_list(rrule)) - function ChainRulesCore.rrule( - ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs - ) - return 1.0, x->(x,x,x) # this will not be call so return doesn't matter + @testset "normal type sigs" begin + old_rrule_list = collect(_rule_list(rrule)) + function ChainRulesCore.rrule( + ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs + ) + return 1.0, x->(x,x,x) # this will not be call so return doesn't matter + end + # New rule should not have appeared + @test collect(_rule_list(rrule)) == old_rrule_list + end + @testset "UnionAll type sigs" begin + old_rrule_list = collect(_rule_list(rrule)) + function ChainRulesCore.rrule( + ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f::F, xs + ) where F <: Function + return 1.0, x->(x,x,x) # this will not be call so return doesn't matter + end + # New rule should not have appeared + @test collect(_rule_list(rrule)) == old_rrule_list + # Above would error if we were not handling UnionAll's right end - # New rule should not have appeared - @test collect(_rule_list(rrule)) == old_rrule_list end end end