Commit 6efb2d2

Merge pull request #398 from JuliaDiff/ox/optout
Add opting out of rules
2 parents 460a559 + db35df7 commit 6efb2d2

7 files changed

+263
-5
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ makedocs(
4848
"Introduction" => "index.md",
4949
"FAQ" => "FAQ.md",
5050
"Rule configurations and calling back into AD" => "config.md",
51+
"Opting out of rules" => "opting_out_of_rules.md",
5152
"Writing Good Rules" => "writing_good_rules.md",
5253
"Complex Numbers" => "complex.md",
5354
"Deriving Array Rules" => "arrays.md",

docs/src/api.md

Lines changed: 2 additions & 0 deletions
@@ -50,4 +50,6 @@ ProjectTo
5050
```@docs
5151
ChainRulesCore.AbstractTangent
5252
ChainRulesCore.debug_mode
53+
ChainRulesCore.no_rrule
54+
ChainRulesCore.no_frule
5355
```

docs/src/opting_out_of_rules.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
1+
# [Opting out of rules](@id opt_out)
2+
3+
It is common to define rules fairly generically.
4+
Often they are as generic as (or more generic than) the original primal method they match.
5+
Sometimes this is not the correct behaviour.
6+
Sometimes the AD can do better than this human-defined rule.
7+
If this is generally the case, then we should not have the rule defined at all.
8+
But if it is only the case for a particular set of types, then we want to opt out for just those types.
9+
This is done with the [`@opt_out`](@ref) macro.
10+
11+
Consider an `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself):
12+
```julia
13+
function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:)
14+
y = sum(x; dims=dims)
15+
project = ProjectTo(x)
16+
function sum_pullback(ȳ)
17+
# broadcasting the two works out the size no matter `dims`
18+
# project makes sure we stay in the same vector subspace as `x`
19+
# no putting in off-diagonal entries in Diagonal etc
20+
x̄ = project(broadcast(last∘tuple, x, ȳ))
21+
return (NoTangent(), x̄)
22+
end
23+
return y, sum_pullback
24+
end
25+
```
26+
27+
That is a fairly reasonable `rrule` for the vast majority of cases.
28+
29+
You might have a custom array type for which you could write a faster rule.
30+
For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
31+
To do that, you can indeed write another more specific [`rrule`](@ref).
32+
But another case is where the AD system itself would generate more optimized code.
33+
34+
For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
35+
Its `sum` method basically just calls `sum` on its parent.
36+
It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
37+
For example, by avoiding the overhead of [`project`ing](@ref ProjectTo).
38+
39+
To opt out of using the generic `rrule` and to allow the AD system to do its own thing we use the
40+
[`@opt_out`](@ref) macro, to say not to use it for the `sum` of `NamedDimsArray`s.
41+
42+
```julia
43+
@opt_out rrule(::typeof(sum), ::NamedDimsArray)
44+
```
45+
46+
We could even opt out for all one-argument functions.
47+
```julia
48+
@opt_out rrule(::Any, ::NamedDimsArray)
49+
```
50+
Though this is likely to cause some method-ambiguities.
51+
52+
The same can be done for `frule`s with `@opt_out frule`.
53+
It can also be done while passing in a [`RuleConfig`](@ref config).
54+
55+
56+
!!! warning "If the general rule uses a config, the opt-out must also"
57+
Following the same principles as for [rules with config](@ref config), a rule with a `RuleConfig` argument will take precedence over one without, including if that one is an opt-out rule.
58+
But if the general rule does not use a config, then the opt-out rule *can* use a config.
59+
This allows you, for example, to opt out for only one particular AD system, by writing an opt-out rule that takes that AD system's config.
60+
61+
62+
## How to support this (for AD implementers)
63+
64+
We provide two ways to know that a rule has been opted out of.
65+
66+
### `rrule` / `frule` returns `nothing`
67+
68+
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.
69+
70+
If you are in a position to generate code in response to values returned by function calls, then you can do something like:
71+
```julia
72+
res = rrule(f, xs)
73+
if res === nothing
74+
y, pullback = perform_ad_via_decomposition(f, xs)  # do AD without hitting the rrule
75+
else
76+
y, pullback = res
77+
end
78+
```
79+
The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch.
80+
81+
### `no_rrule` / `no_frule` has a method
82+
83+
`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref).
84+
The body of this method doesn't matter; what matters is that it adds an entry to the method table.
85+
The simplest thing you can do with this is to not support opting out.
86+
To do this, filter out all methods from the `rrule`/`frule` method table that also occur in the `no_rrule`/`no_frule` table.
87+
This avoids ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring).
88+
This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule.
89+
90+
In a more complex approach, you can use this to generate code that triggers your AD.
91+
If, for a given signature, there is a more specific method in the `no_rrule`/`no_frule` method table than the one that would be hit from the `rrule`/`frule` table
92+
(excluding the one that exactly matches, which will return `nothing`), then you know that the rule should not be used.
93+
You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined,
94+
and then `invoke` it.
95+
96+
97+
98+
[^1]: It is also possible that this is not the case. Benchmark your real use cases.

src/ChainRulesCore.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
3+
using Base.Meta
34
using LinearAlgebra
45
using SparseArrays: SparseVector, SparseMatrixCSC
56
using Compat: hasfield
@@ -9,7 +10,7 @@ export frule, rrule # core function
910
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
1011
export frule_via_ad, rrule_via_ad
1112
# definition helper macros
12-
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
13+
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1314
export ProjectTo, canonicalize, unthunk # differential operations
1415
export add!! # gradient accumulation operations
1516
# differentials

src/rule_definition_tools.jl

Lines changed: 75 additions & 4 deletions
@@ -1,10 +1,13 @@
11
# These are some macros (and supporting functions) to make it easier to define rules.
2-
using Base.Meta
32

3+
# Note: must be declared before it is used, which is later in this file.
44
macro strip_linenos(expr)
55
return esc(Base.remove_linenums!(expr))
66
end
77

8+
############################################################################################
9+
### @scalar_rule
10+
811
"""
912
@scalar_rule(f(x₁, x₂, ...),
1013
@setup(statement₁, statement₂, ...),
@@ -88,7 +91,6 @@ macro scalar_rule(call, maybe_setup, partials...)
8891
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
8992
rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
9093

91-
############################################################################
9294
# Final return: building the expression to insert in the place of this macro
9395
code = quote
9496
if !($f isa Type) && fieldcount(typeof($f)) > 0
@@ -114,7 +116,6 @@ returns (in order) the correctly escaped:
114116
- `partials`: which are all `Expr{:tuple,...}`
115117
"""
116118
function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
117-
############################################################################
118119
# Setup: normalizing input form etc
119120

120121
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
@@ -275,6 +276,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
275276
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
276277
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)
277278

279+
############################################################################################
280+
### @non_differentiable
281+
278282
"""
279283
@non_differentiable(signature_expression)
280284
@@ -394,7 +398,74 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
394398
end
395399

396400

397-
###########
401+
############################################################################################
402+
# @opt_out
403+
404+
"""
405+
@opt_out frule([config], _, f, args...)
406+
@opt_out rrule([config], f, args...)
407+
408+
This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method,
409+
that says to use the AD system to differentiate it.
410+
411+
For example, consider some function `foo(x::AbstractArray)`.
412+
In general, you know an efficient and generic way to implement its `rrule`.
413+
You do so (likely making use of [`ProjectTo`](@ref)).
414+
But it actually turns out that for some `FancyArray` type it is better to let the AD do its
415+
thing.
416+
417+
Then you would write something like:
418+
```julia
419+
function rrule(::typeof(foo), x::AbstractArray)
420+
foo_pullback(ȳ) = ...
421+
return foo(x), foo_pullback
422+
end
423+
424+
@opt_out rrule(::typeof(foo), ::FancyArray)
425+
```
426+
427+
This will generate an [`rrule`](@ref) that returns `nothing`,
428+
and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref).
429+
430+
The same applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref).
431+
432+
For more information see the [documentation on opting out of rules](@ref opt_out).
433+
"""
434+
macro opt_out(expr)
435+
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))
436+
437+
return @strip_linenos quote
438+
$(esc(no_rule_target)) = nothing
439+
$(esc(expr)) = nothing
440+
end
441+
end
442+
443+
"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`."
444+
function _no_rule_target_rewrite!(expr::Expr)
445+
length(expr.args)===0 && error("Malformed method expression. $expr")
446+
if expr.head === :call || expr.head === :where
447+
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
448+
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
449+
expr = _no_rule_target_rewrite!(expr.args[end])
450+
else
451+
error("Malformed method expression. $(expr)")
452+
end
453+
return expr
454+
end
455+
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
456+
function _no_rule_target_rewrite!(call_target::Symbol)
457+
return if call_target == :rrule
458+
:(ChainRulesCore.no_rrule)
459+
elseif call_target == :frule
460+
:(ChainRulesCore.no_frule)
461+
else
462+
error("Unexpected opt-out target. Expected frule or rrule, got: $call_target")
463+
end
464+
end
465+
466+
467+
468+
############################################################################################
398469
# Helpers
399470

400471
"""

src/rules.jl

Lines changed: 57 additions & 0 deletions
@@ -139,3 +139,60 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance
139139
function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...)
140140
return rrule_kwfunc(kws, rrule, args...)
141141
end
142+
143+
##############################################################
144+
### Opt out functionality
145+
146+
const NO_RRULE_DOC = """
147+
no_rrule
148+
149+
This is a piece of infrastructure supporting opting out of [`rrule`](@ref).
150+
It follows the signature for `rrule` exactly.
151+
A collection of type-tuples is stored in its method-table.
152+
If something has this defined, it means that there must also be an `rrule`
153+
defined that returns `nothing`.
154+
155+
!!! warning "do not overload no_rrule directly"
156+
It is fine and intended to query the method table of `no_rrule`.
157+
It is not safe to add to that directly, as corresponding changes also need to be made to
158+
`rrule`.
159+
The [`@opt_out`](@ref) macro does both these things, and so should almost always be used
160+
rather than defining a method of `no_rrule` directly.
161+
162+
### Mechanics
163+
note: when the text below says methods `==` it actually means:
164+
`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself.
165+
166+
To decide whether to opt out using this mechanism:
167+
- find the most specific method of `rrule` and `no_rrule`, e.g. with `Base.which`
168+
- if the method of `no_rrule` `==` the method of `rrule`, then you should opt out
169+
170+
To just ignore the fact that rules can be opted-out from, and that some rules thus return
171+
`nothing`, filter the list of methods of `rrule` to remove those that are `==` to ones
172+
that occur in the method table of `no_rrule`.
173+
174+
Note also when doing this you must still also handle falling back from rule with config, to
175+
rule without config.
176+
177+
On the other hand, if your AD can work with `rrule`s that return `nothing`, then it is
178+
simpler to just use that mechanism for opting out; and you don't need to worry about this
179+
at all.
180+
181+
For more information see the [documentation on opting out of rules](@ref opt_out).
182+
"""
183+
184+
"""
185+
$NO_RRULE_DOC
186+
187+
See also [`ChainRulesCore.no_frule`](@ref).
188+
"""
189+
function no_rrule end
190+
no_rrule(::Any, ::Vararg{Any}) = nothing
191+
192+
"""
193+
$(replace(NO_RRULE_DOC, "rrule"=>"frule"))
194+
195+
See also [`ChainRulesCore.no_rrule`](@ref).
196+
"""
197+
function no_frule end
198+
no_frule(ȧrgs, f, ::Vararg{Any}) = nothing

test/rules.jl

Lines changed: 28 additions & 0 deletions
@@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
148148
@test_skip ∂xr isa Float64 # to be made true with projection
149149
@test_skip ∂xr ≈ real(∂x)
150150
end
151+
152+
153+
@testset "@opt_out" begin
154+
first_oa(x, y) = x
155+
@scalar_rule(first_oa(x, y), (1, 0))
156+
@opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32
157+
@opt_out(
158+
ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32
159+
)
160+
161+
@testset "rrule" begin
162+
@test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0)
163+
@test rrule(first_oa, 3f0, 4f0) === nothing
164+
165+
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m
166+
m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32
167+
end)
168+
end
169+
170+
@testset "frule" begin
171+
@test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1)
172+
@test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing
173+
174+
@test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m
175+
m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32
176+
end)
177+
end
178+
end
151179
end
