docs/src/opting_out_of_rules.md (19 additions, 21 deletions)
@@ -10,44 +10,44 @@ This is done with the [`@opt_out`](@ref) macro.

Consider a `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself):
```julia
        # broadcasting the two works out the size no matter `dims`
-       # project makes sure we stay in the same vector subspace as `x`
-       # no putting in off-diagonal entries in Diagonal etc
-       x̄ = project(broadcast(last∘tuple, x, ȳ))
+       x̄ = project(fill(ȳ, size(x)))
        return (NoTangent(), x̄)
    end
    return y, sum_pullback
end
```

That is a fairly reasonable `rrule` for the vast majority of cases.
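
Pieced together, the updated rule reads roughly like the sketch below. This is a reconstruction: the diff above only shows the pullback body, so the function header and the `y = sum(x)` / `project = ProjectTo(x)` lines are assumptions about the surrounding code, not text taken from the file.

```julia
using ChainRulesCore  # rrule, ProjectTo, NoTangent

# A guess at the complete simplified rule after the change above is applied.
function ChainRulesCore.rrule(::typeof(sum), x::AbstractArray)
    y = sum(x)
    project = ProjectTo(x)  # keeps x̄ in the same subspace as x (e.g. a Diagonal stays a Diagonal)
    function sum_pullback(ȳ)
        # every element of x contributes ȳ, so the cotangent has ȳ in every position
        x̄ = project(fill(ȳ, size(x)))
        return (NoTangent(), x̄)
    end
    return y, sum_pullback
end
```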
You might have a custom array type for which you could write a faster rule.
- For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`.
- To do that, you can indeed write another more specific [`rrule`](@ref).
- But another case is where the AD system itself would generate a more optimized case.
+ In that case you would do just that: write a faster, more specific `rrule`.
+ But sometimes ADing the (faster, more specific) primal for your custom array type would yield a faster pullback without you having to write a `rrule` by hand.

- For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type.
- Its sum method is basically just to call `sum` on its parent.
- It is entirely conceivable[^1] that the AD system can do better than our `rrule` here.
- For example by avoiding the overhead of [`project`ing](@ref ProjectTo).
+ Consider summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix.
+ A skew-symmetric matrix has structural zeros on the diagonal, and each off-diagonal entry is paired with its negation.
+ Thus the sum is always going to be zero.
+ As such, the author of that matrix type would probably have overloaded `sum(x::SkewSymmetric{T}) where T = zero(T)`.
+ ADing this would compute the tangent for `x` as `ZeroTangent()`, and it would be very fast, since the AD can see that `x` is never used on the right-hand side.
+ In contrast, the generic method for `AbstractArray` defined above would have to allocate the fill array and then compute the skew projection.
+ Only to find that the output would be projected to `SkewSymmetric(zeros(T))` anyway (slower, and a less useful type).
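
To make this concrete, here is a minimal sketch of what such a type and its `sum` overload might look like. `SkewSymmetric` is hypothetical (the link above points at the mathematical definition, not at an existing Julia type), so every detail below is an illustrative assumption.

```julia
# Hypothetical minimal SkewSymmetric wrapper, purely for illustration;
# the parent matrix is assumed to already hold skew-symmetric data.
struct SkewSymmetric{T} <: AbstractMatrix{T}
    parent::Matrix{T}
end
Base.size(x::SkewSymmetric) = size(x.parent)
Base.getindex(x::SkewSymmetric, i::Int, j::Int) =
    i == j ? zero(eltype(x)) : x.parent[i, j]  # structural zeros on the diagonal

# Every off-diagonal entry is paired with its negation, so the sum is always zero;
# the author of the type would therefore likely just write:
Base.sum(x::SkewSymmetric{T}) where {T} = zero(T)
```

ADing that `sum` method never touches `x`, which is why the tangent comes back as `ZeroTangent()` essentially for free.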

To opt out of using the generic `rrule` and to allow the AD system to do its own thing, we use the
- [`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`.
+ [`@opt_out`](@ref) macro, to say not to use it for `sum` of `SkewSymmetric`.

```julia
- @opt_out rrule(::typeof(sum), ::NamedDimsArray)
+ @opt_out rrule(::typeof(sum), ::SkewSymmetric)
```
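
If I have the mechanism right, `@opt_out` registers a rule for this signature that simply returns `nothing` (and records the signature in the `no_rrule` method table), so, reusing the hypothetical `SkewSymmetric` sketch from above, a check along these lines should hold:

```julia
using ChainRulesCore: rrule

A = rand(3, 3)
x = SkewSymmetric(A - A')   # hypothetical wrapper from the earlier sketch
rrule(sum, x) === nothing   # the generic AbstractArray rule no longer applies here;
                            # an AD system that sees `nothing` falls back to ADing the primal
```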

- We could even opt-out for all 1 arg functions.
+ Perhaps we might not want to ever use rules for `SkewSymmetric`, because we have determined that it is always better to leave it to the AD, unless a very specific rule has been written[^1].
+ We could then opt out for all 1-arg functions.

```julia
- @opt_out rrule(::Any, ::NamedDimsArray)
+ @opt_out rrule(::Any, ::SkewSymmetric)
```
- Though this is likely to cause some method-ambiguities.
+ Though this is likely to cause some method ambiguities if we do it for more signatures, we can resolve those.
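
As a sketch of the kind of ambiguity meant here (still assuming the hypothetical `SkewSymmetric`): the blanket opt-out overlaps with the generic `sum` rule from the top of this page, and the usual Julia fix of adding a method at the intersection of the two signatures resolves it.

```julia
using ChainRulesCore  # @opt_out, rrule

# For the call rrule(sum, x::SkewSymmetric) these two signatures overlap and
# neither is strictly more specific than the other:
#
#     rrule(::typeof(sum), ::AbstractArray)   # the generic rule shown above
#     rrule(::Any, ::SkewSymmetric)           # our blanket opt-out
#
# Defining a method at the intersection of the two signatures resolves the
# ambiguity -- which is exactly the specific opt-out we already wrote:
@opt_out rrule(::typeof(sum), ::SkewSymmetric)
```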

Similar can be done with `@opt_out frule`.
It can also be done passing in a [`RuleConfig`](@ref config).
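
Presumably the config form looks something like the sketch below; the particular trait used here is an assumption, chosen purely for illustration.

```julia
using ChainRulesCore  # @opt_out, rrule, RuleConfig, HasReverseMode

# Sketch: opt out of the rule only for AD systems whose RuleConfig advertises
# reverse-mode support (hypothetical choice of trait, for illustration only).
@opt_out rrule(::RuleConfig{>:HasReverseMode}, ::typeof(sum), ::SkewSymmetric)
```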
@@ -93,6 +93,4 @@ If for a given signature there is a more specific method in the `no_rrule`/`no_f
You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined,
and then `invoke` it.
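
As a rough reminder of how `invoke` works (a sketch reusing the hypothetical `SkewSymmetric`, not necessarily the exact call intended here), one can explicitly call the method that a plain `AbstractArray` would hit:

```julia
A = rand(3, 3)
x = SkewSymmetric(A - A')
# Calls sum(::AbstractArray), bypassing the more specific sum(::SkewSymmetric)
# method that sits above it in the method table.
invoke(sum, Tuple{AbstractArray}, x)
```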

- [^1]: It is also possible that this is not the case. Benchmark your real use cases.
+ [^1]: This seems unlikely, but it is possible; there is a lot of structure that can be taken advantage of for some matrix types.