Commit 8fa0ecb

committed
opting out of rules
1 parent 23ec91d commit 8fa0ecb

File tree

7 files changed

+242
-5
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ makedocs(
4848
"Introduction" => "index.md",
4949
"FAQ" => "FAQ.md",
5050
"Rule configurations and calling back into AD" => "config.md",
51+
"Opting out of rules" => "opting_out_of_rules.md",
5152
"Writing Good Rules" => "writing_good_rules.md",
5253
"Complex Numbers" => "complex.md",
5354
"Deriving Array Rules" => "arrays.md",

docs/src/api.md

Lines changed: 2 additions & 0 deletions
@@ -50,4 +50,6 @@ ProjectTo
5050
```@docs
5151
ChainRulesCore.AbstractTangent
5252
ChainRulesCore.debug_mode
53+
ChainRulesCore.no_rrule
54+
ChainRulesCore.no_frule
5355
```

docs/src/opting_out_of_rules.md

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
1+
# Opting out of rules
2+
3+
It is common to define rules fairly generically.
4+
Often they are as generic as (or more generic than) the corresponding primal method.
5+
Sometimes this is not the correct behaviour.
6+
Sometimes the AD can do better than this human-defined rule.
7+
If this is generally the case, then we should not have the rule defined at all.
8+
But if it is only the case for a particular set of types, then we want to opt out for just those types.
9+
This is done with the [`@opt_out`](@ref) macro.
10+
11+
Consider an `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself):
12+
```julia
13+
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:)
14+
y = sum(x; dims=dims)
15+
project = ProjectTo(x)
16+
function sum_pullback(ȳ)
17+
# broadcasting the two works out the size no matter the `dims`
18+
# project makes sure we stay in the same vector subspace as `x`
19+
# no putting in off-diagonal entries in Diagonal etc
20+
x̄ = project(broadcast(last∘tuple, x, ȳ))
21+
return (NoTangent(), x̄)
22+
end
23+
return y, sum_pullback
24+
end
25+
```
26+
27+
That is a fairly reasonable `rrule` for the vast majority of cases.
28+
29+
You might have a custom array type for which you could write a faster rule.
30+
For example, the pullback for summing a `SkewSymmetric` matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
31+
To do that, you can indeed write another more specific [`rrule`](@ref).
32+
But another case is where the AD system itself would generate more optimized code.
33+
34+
For example, a [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
35+
Its `sum` method basically just calls `sum` on its parent.
36+
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
37+
For example by avoiding the overhead of [`project`ing](@ref ProjectTo).
38+
39+
To opt-out of using the `rrule` and to allow the AD system to do its own thing we use the
40+
[`@opt_out`](@ref) macro to say that the rule should not be used for `sum`.
41+
42+
```julia
43+
@opt_out rrule(::typeof(sum), ::NamedDimsArray)
44+
```
45+
46+
We could even opt out for all one-argument functions.
47+
```julia
48+
@opt_out rrule(::Any, ::NamedDimsArray)
49+
```
50+
Though this is likely to cause some method-ambiguities.
51+
52+
Similarly, you can `@opt_out` of an `frule`.
53+
This can also be done for rules that take a [`RuleConfig`](@ref config).
54+
55+
56+
## How to support this (for AD implementers)
57+
58+
We provide two ways to know that a rule has been opted out of.
59+
60+
## `rrule` / `frule` returns `nothing`
61+
62+
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.
63+
64+
If you are in a position to generate code in response to values returned by function calls, then you can do something like:
65+
```julia
66+
res = rrule(f, xs)
67+
if res === nothing
68+
y, pullback = perform_ad_via_decomposition(f, xs)  # do AD without hitting the rrule
69+
else
70+
y, pullback = res
71+
end
72+
```
73+
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch.
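
As a rough illustration of this pattern, here is a self-contained sketch that uses only Base Julia. The names `toy_rrule`, `fallback_ad`, and `value_and_pullback` are hypothetical stand-ins (for `rrule`, for the AD system's own differentiation, and for an AD entry point); none of them is part of ChainRulesCore.

```julia
# Hypothetical stand-ins; nothing here is part of ChainRulesCore.
double(x) = 2x

# A generic hand-written rule: returns the primal value and a pullback.
toy_rrule(::typeof(double), x::Number) = (double(x), ȳ -> (nothing, 2 * ȳ))

# An opted-out signature: the more specific method returns `nothing`.
toy_rrule(::typeof(double), x::Float32) = nothing

# Stand-in for the AD system differentiating `double` itself (hard-coded here).
fallback_ad(::typeof(double), x) = (double(x), ȳ -> (nothing, 2 * ȳ))

function value_and_pullback(f, x)
    res = toy_rrule(f, x)
    if res === nothing
        # The rule was opted out of, so fall back to the AD's own machinery.
        y, pullback = fallback_ad(f, x)
        return y, pullback, :ad
    else
        y, pullback = res
        return y, pullback, :rule
    end
end

y64, pb64, src64 = value_and_pullback(double, 3.0)  # uses the generic rule
y32, pb32, src32 = value_and_pullback(double, 3f0)  # hits the opted-out method
```

The third return value is only there to make visible which path was taken; a real AD system would not need it.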
74+
75+
## `no_rrule` / `no_frule` has a method
76+
77+
`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref).
78+
What this method does when called doesn't matter; what matters is its method table.
79+
A simple thing you can do with this is not support opting out.
80+
To do this, filter out all methods from the `rrule`/`frule` method table that also occur in the `no_rrule`/`no_frule` table.
81+
This avoids ever hitting an `rrule`/`frule` that returns `nothing`, which would otherwise make your library error.
82+
This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule.
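
The filtering step might look roughly like the following sketch, written against Base Julia reflection only. `demo_rrule` and `demo_no_rrule` are hypothetical stand-ins for the `rrule` and `no_rrule` method tables, and `arg_sig` is a helper introduced here; it assumes the method signatures have no top-level `where` clauses.

```julia
# Hypothetical stand-ins for `rrule` and `ChainRulesCore.no_rrule`.
function demo_rrule end
function demo_no_rrule end

demo_rrule(::typeof(sum), x::AbstractArray) = (sum(x), ȳ -> (nothing, fill(ȳ, size(x))))
demo_rrule(::typeof(sum), x::Vector{Float32}) = nothing     # what an opt-out would define
demo_no_rrule(::typeof(sum), x::Vector{Float32}) = nothing  # records the opt-out

# Argument part of a method signature: drop the type of the function itself.
# Assumption: `m.sig` is a plain `Tuple` type, not a `UnionAll`.
arg_sig(m::Method) = Tuple{m.sig.parameters[2:end]...}

opted_out = [arg_sig(m) for m in methods(demo_no_rrule)]

# Keep only the rule methods whose signature is not recorded as opted out.
usable_rules = [
    m for m in methods(demo_rrule) if !any(s -> s == arg_sig(m), opted_out)
]
```

Only the generic `AbstractArray` method remains in `usable_rules`, so an AD that dispatches over this filtered list never sees a rule that returns `nothing`.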
83+
84+
In a more complex approach, you can use this to generate code that triggers your AD.
85+
If for a given signature there is a more specific method in the `no_rrule`/`no_frule` method table than the one that would be hit from the `rrule`/`frule` table
86+
(excluding the one that matches exactly, which will return `nothing`), then you know that the rule should not be used.
87+
You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined,
88+
and then `invoke` it.
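
The specificity check described above could be sketched as follows, again using only Base Julia reflection (`hasmethod`, `which`, and `<:` on signature types). `lookup_rrule`, `lookup_no_rrule`, `arg_types`, and `is_opted_out` are hypothetical names introduced for this sketch, and it leaves out the generic fallback methods that the real `rrule` and `no_rrule` have.

```julia
# Hypothetical stand-ins for the `rrule` and `no_rrule` method tables.
function lookup_rrule end
function lookup_no_rrule end

lookup_rrule(::typeof(sum), x::AbstractArray) = :generic_rule
lookup_rrule(::typeof(sum), x::Vector{Int}) = :int_rule

# Two opt-outs: each defines a rule returning `nothing` plus a table entry.
lookup_rrule(::typeof(sum), x::Vector{Float32}) = nothing
lookup_no_rrule(::typeof(sum), x::Vector{Float32}) = nothing
lookup_rrule(::typeof(sum), x::AbstractVector{<:Integer}) = nothing
lookup_no_rrule(::typeof(sum), x::AbstractVector{<:Integer}) = nothing

# Argument part of a method signature (assumes no top-level `where` clause).
arg_types(m::Method) = Tuple{m.sig.parameters[2:end]...}

function is_opted_out(argtypes::Type{<:Tuple})
    hasmethod(lookup_rrule, argtypes) || return false
    hasmethod(lookup_no_rrule, argtypes) || return false
    rule_method = which(lookup_rrule, argtypes)        # most specific rule
    no_rule_method = which(lookup_no_rrule, argtypes)  # most specific opt-out
    # Opted out only if the opt-out is at least as specific as the rule hit.
    return arg_types(no_rule_method) <: arg_types(rule_method)
end
```

In this sketch `Vector{Float32}` and `Vector{Int32}` count as opted out, while `Vector{Int}` does not: its dedicated rule is more specific than the `AbstractVector{<:Integer}` opt-out, so the rule still applies.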
89+
90+
91+
92+
[^1]: It is also possible that this is not the case. Benchmark your real use cases.

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ export frule, rrule # core function
1010
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
1111
export frule_via_ad, rrule_via_ad
1212
# definition helper macros
13-
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
13+
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1414
export ProjectTo, canonicalize, unthunk # differential operations
1515
export add!! # gradient accumulation operations
1616
# differentials

src/rule_definition_tools.jl

Lines changed: 69 additions & 4 deletions
@@ -1,5 +1,10 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
22

3+
# Note: must be declared before it is used, which is later in this file.
4+
macro strip_linenos(expr)
5+
return esc(Base.remove_linenums!(expr))
6+
end
7+
38
############################################################################################
49
### @scalar_rule
510

@@ -323,7 +328,7 @@ macro non_differentiable(sig_expr)
323328
:($(primal_name)($(unconstrained_args...)))
324329
else
325330
normal_args = unconstrained_args[1:end-1]
326-
var_arg = s[end]
331+
var_arg = unconstrained_args[end]
327332
:($(primal_name)($(normal_args...), $(var_arg)...))
328333
end
329334

@@ -392,13 +397,73 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
392397
end
393398
end
394399

400+
395401
############################################################################################
396-
# Helpers
402+
# @opt_out
397403

398-
macro strip_linenos(expr)
399-
return esc(Base.remove_linenums!(expr))
404+
"""
405+
@opt_out frule([config], _, f, args...)
406+
@opt_out rrule([config], f, args...)
407+
408+
This allows you to opt-out of a `frule` or `rrule` by providing a more specific method,
409+
that says to use the AD system to solve it.
410+
411+
For example, consider some function `foo(x::AbstractArray)`.
412+
In general, you know an efficient and generic way to implement its `rrule`.
413+
You do so (likely making use of [`ProjectTo`](@ref)).
414+
But it actually turns out that for some `FancyArray` type it is better to let the AD do its
415+
thing.
416+
417+
Then you would write something like:
418+
```julia
419+
function rrule(::typeof(foo), x::AbstractArray)
420+
foo_pullback(ȳ) = ...
421+
return foo(x), foo_pullback
400422
end
401423
424+
@opt_out rrule(::typeof(foo), ::FancyArray)
425+
```
426+
427+
This will generate a [`rrule`](@ref) that returns `nothing`,
428+
and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref).
429+
430+
The same applies to [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref).
431+
"""
432+
macro opt_out(expr)
433+
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))
434+
435+
return @strip_linenos quote
436+
$(esc(no_rule_target)) = nothing
437+
$(esc(expr)) = nothing
438+
end
439+
end
440+
441+
function _no_rule_target_rewrite!(call_target::Symbol)
442+
return if call_target == :rrule
443+
:(ChainRulesCore.no_rrule)
444+
elseif call_target == :frule
445+
:(ChainRulesCore.no_frule)
446+
else
447+
error("Unexpected opt-out target. Expected frule or rrule, got: $call_target")
448+
end
449+
end
450+
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
451+
function _no_rule_target_rewrite!(expr::Expr)
452+
length(expr.args)===0 && error("Malformed method expression. $expr")
453+
if expr.head === :call || expr.head === :where
454+
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
455+
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
456+
expr = _no_rule_target_rewrite!(expr.args[end])
457+
else
458+
error("Malformed method expression. $(expr)")
459+
end
460+
return expr
461+
end
462+
463+
464+
############################################################################################
465+
# Helpers
466+
402467
"""
403468
_isvararg(expr)
404469

src/rules.jl

Lines changed: 49 additions & 0 deletions
@@ -139,3 +139,52 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance
139139
function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...)
140140
return rrule_kwfunc(kws, rrule, args...)
141141
end
142+
143+
##############################################################
144+
### Opt out functionality
145+
146+
const NO_RRULE_DOC = """
147+
no_rrule
148+
149+
This is an implementation detail for opting out of [`rrule`](@ref).
150+
It follows the signature for `rrule` exactly.
151+
We use it as a way to store a collection of type-tuples in its method-table.
152+
If something has this defined, it means that it must also have a `rrule`,
153+
that returns `nothing`.
154+
155+
### Mechanics
156+
note: when this says methods `==` or `<:` it actually means:
157+
`parameters(m.sig)[2:end]` rather than the method object `m` itself.
158+
159+
To decide whether to opt out using this mechanism:
160+
- find the most specific method of `rrule`
161+
- find the most specific method of `no_rrule`
162+
- if the method of `no_rrule` `<:` the method of `rrule`, then you should opt out
163+
164+
To just ignore the fact that rules can be opted-out from, and that some rules thus return
165+
`nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones
166+
that occur in the method table of `no_rrule`.
167+
168+
Note also when doing this you must still also handle falling back from rule with config, to
169+
rule without config.
170+
171+
On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is
172+
simpler to just use that mechanism for opting out; and you don't need to worry about this
173+
at all.
174+
"""
175+
176+
"""
177+
$NO_RRULE_DOC
178+
179+
See also [`ChainRulesCore.no_frule`](@ref).
180+
"""
181+
function no_rrule end
182+
no_rrule(::Any, ::Vararg{Any}) = nothing
183+
184+
"""
185+
$(replace(NO_RRULE_DOC, "rrule"=>"frule"))
186+
187+
See also [`ChainRulesCore.no_rrule`](@ref).
188+
"""
189+
function no_frule end
190+
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing

test/rules.jl

Lines changed: 28 additions & 0 deletions
@@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
148148
@test_skip ∂xr isa Float64 # to be made true with projection
149149
@test_skip ∂xr ≈ real(∂x)
150150
end
151+
152+
153+
@testset "@opt_out" begin
154+
first_oa(x, y) = x
155+
@scalar_rule(first_oa(x, y), (1, 0))
156+
@opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32
157+
@opt_out(
158+
ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32
159+
)
160+
161+
@testset "rrule" begin
162+
@test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0)
163+
@test rrule(first_oa, 3f0, 4f0) === nothing
164+
165+
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m
166+
m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32
167+
end)
168+
end
169+
170+
@testset "frule" begin
171+
@test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1)
172+
@test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing
173+
174+
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m
175+
m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32
176+
end)
177+
end
178+
end
151179
end
