Skip to content

Commit 3acd962

Browse files
authored
use a functor for projection (#385)
1 parent b4f2cfa commit 3acd962

File tree

8 files changed

+428
-2
lines changed

8 files changed

+428
-2
lines changed

docs/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.7"
16+
version = "0.10.11"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]

docs/src/api.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ Pages = ["config.jl"]
4141
Private = false
4242
```
4343

44+
## ProjectTo
45+
```@docs
46+
ProjectTo
47+
```
48+
4449
## Internal
4550
```@docs
4651
ChainRulesCore.AbstractTangent

docs/src/writing_good_rules.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,40 @@ Examples being:
6363
- There is only one derivative being returned, so from the fact that the user called
6464
`frule`/`rrule` they clearly will want to use that one.
6565

66+
## Ensure you remain in the primal's subspace (i.e. use `ProjectTo` appropriately)
67+
68+
Rules with abstractly-typed arguments may return incorrect answers when called with certain concrete types.
69+
A classic example is the matrix-matrix multiplication rule, a naive definition of which follows:
70+
```julia
71+
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
72+
function times_pullback(ȳ)
73+
dA = ȳ * B'
74+
dB = A' * ȳ
75+
return NoTangent(), dA, dB
76+
end
77+
return A * B, times_pullback
78+
end
79+
```
80+
When computing `*(A, B)`, where `A isa Diagonal` and `B isa Matrix`, the output will be a `Matrix`.
81+
As a result, `ȳ` in the pullback will be a `Matrix`, and consequently `dA` for an `A isa Diagonal` will be a `Matrix`, which is wrong.
82+
Not only is it the wrong type, but it can contain non-zeros off the diagonal, which is not possible: such a tangent lies outside of the subspace of `Diagonal` matrices.
83+
While specialised rules can indeed be written for the `Diagonal` case, there are many other types and we don't want to be forced to write a rule for each of them.
84+
Instead, `project_A = ProjectTo(A)` can be used (outside the pullback) to extract an object that knows how to project onto the type of `A` (e.g. also knows the size of the array).
85+
This object can be called with a tangent `ȳ * B'`, by doing `project_A(ȳ * B')`, to project it on the tangent space of `A`.
86+
The correct rule then looks like
87+
```julia
88+
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
89+
project_A = ProjectTo(A)
90+
project_B = ProjectTo(B)
91+
function times_pullback(ȳ)
92+
dA = ȳ * B'
93+
dB = A' * ȳ
94+
return NoTangent(), project_A(dA), project_B(dB)
95+
end
96+
return A * B, times_pullback
97+
end
98+
```
99+
66100
## Structs: constructors and functors
67101

68102
To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.

src/ChainRulesCore.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
1010
export frule_via_ad, rrule_via_ad
1111
# definition helper macros
1212
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
13-
export canonicalize, extern, unthunk # differential operations
13+
export ProjectTo, canonicalize, extern, unthunk # differential operations
1414
export add!! # gradient accumulation operations
1515
# differentials
1616
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
@@ -26,6 +26,7 @@ include("differentials/notimplemented.jl")
2626

2727
include("differential_arithmetic.jl")
2828
include("accumulation.jl")
29+
include("projection.jl")
2930

3031
include("config.jl")
3132
include("rules.jl")

src/differentials/composite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ backing(x::NamedTuple) = x
133133
backing(x::Dict) = x
134134
backing(x::Tangent) = getfield(x, :backing)
135135

136+
# For generic structs
136137
function backing(x::T)::NamedTuple where T
137138
# note: all computation outside the if @generated happens at runtime.
138139
# so the first 4 lines of the branches look the same, but cannot be moved out.

src/projection.jl

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
"""
    ProjectTo(x::T)

Return a `ProjectTo{T,...}` functor that projects a differential `dx` onto the same
tangent space as `x`.
The functor closes over whatever information is needed to perform that projection.
For example, when projecting `dx=ZeroTangent()` onto an array `P=Array{T, N}`, the size
of `x` is not available from `P`, so it is stored in the functor.

    (::ProjectTo{T})(dx)

Project the differential `dx` onto the tangent space used to create the `ProjectTo`.
"""
struct ProjectTo{P,D<:NamedTuple}
    info::D
end

# Fill in the `NamedTuple` type parameter from the supplied `info`.
function ProjectTo{P}(info::D) where {P,D<:NamedTuple}
    return ProjectTo{P,D}(info)
end

# Keyword form: the keyword arguments are collected into the `info` `NamedTuple`.
function ProjectTo{P}(; kwargs...) where {P}
    return ProjectTo{P}((; kwargs...))
end
18+
19+
# The `info` field is read with `getfield`, since `getproperty` on a `ProjectTo` is
# overloaded (just below) to forward to the entries of that `NamedTuple`.
function backing(project::ProjectTo)
    return getfield(project, :info)
end

# `p.name` looks `name` up in the backing `NamedTuple`.
function Base.getproperty(p::ProjectTo, name::Symbol)
    return getproperty(backing(p), name)
end

# The property names are those of the backing `NamedTuple`.
function Base.propertynames(p::ProjectTo)
    return propertynames(backing(p))
end
22+
23+
# Print as `ProjectTo{T}` followed by the backing `NamedTuple`, or by `()` when the
# backing is empty.
function Base.show(io::IO, project::ProjectTo{T}) where {T}
    info = backing(project)
    print(io, "ProjectTo{")
    show(io, T)
    print(io, "}")
    return isempty(info) ? print(io, "()") : show(io, info)
end
33+
34+
# Fallback for generic structs: build a `ProjectTo` for every field of `x`
# (recursively, since each field value goes through `ProjectTo` itself).
function ProjectTo(x::T) where {T}
    field_values::NamedTuple = backing(x)
    field_projectors = map(ProjectTo, field_values)
    return ProjectTo{T}(field_projectors)
end
40+
# Project a `Tangent` field by field: each stored sub-projector is applied to the
# matching entry of the canonicalized tangent, and `T` is rebuilt from the results.
function (project::ProjectTo{T})(dx::Tangent) where {T}
    field_projectors = backing(project)
    field_tangents = backing(canonicalize(dx))
    projected = map((proj, d) -> proj(d), field_projectors, field_tangents)
    return construct(T, projected)
end
46+
47+
# Tuples and NamedTuples are not valid tangent types, so refuse to build a projector
# for them.
function ProjectTo(x::T) where {T<:Union{<:Tuple,NamedTuple}}
    msg = "The `x` in `ProjectTo(x)` must be a valid differential, not $x"
    throw(ArgumentError(msg))
end
53+
54+
# Generic
55+
# Thunks are unthunked first; the resulting value is then projected.
function (project::ProjectTo)(dx::AbstractThunk)
    return project(unthunk(dx))
end

# A differential that already has the target type is returned as is.
# This is not always true, but we can special case for when it isn't.
function (::ProjectTo{T})(dx::T) where {T}
    return dx
end

# Zero differentials are replaced by `zero(T)`.
function (::ProjectTo{T})(dx::AbstractZero) where {T}
    return zero(T)
end
58+
59+
# Number
60+
# Numbers need no extra information beyond their type.
function ProjectTo(::T) where {T<:Number}
    return ProjectTo{T}()
end

# General numbers: convert the differential to the target number type.
function (::ProjectTo{T})(dx::Number) where {T<:Number}
    return convert(T, dx)
end

# Real targets: take `real(dx)` before converting.
function (::ProjectTo{T})(dx::Number) where {T<:Real}
    return convert(T, real(dx))
end
63+
64+
# Arrays
65+
# General arrays: keep one projector per element.
function ProjectTo(xs::T) where {T<:Array}
    return ProjectTo{T}(; elements=map(ProjectTo, xs))
end

# Apply each element's projector to the matching entry of `dx`.
function (project::ProjectTo{T})(dx::Array) where {T<:Array}
    projected = map((proj, d) -> proj(d), project.elements, dx)
    return T(projected)
end

# A zero differential is handed to every element projector.
function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array}
    projected = map(proj -> proj(dx), project.elements)
    return T(projected)
end

# Other array types are collected into an `Array` and projected from there.
function (project::ProjectTo{<:Array})(dx::AbstractArray)
    return project(collect(dx))
end
74+
75+
# Arrays{<:Number}: optimized case so we don't need a projector per element
76+
# Store a single element projector (built from `zero(E)`) together with the array size.
function ProjectTo(x::T) where {E<:Number,T<:Array{E}}
    element_projector = ProjectTo(zero(E))
    return ProjectTo{T}(; element=element_projector, size=size(x))
end

# Broadcast the shared element projector over `dx`.
function (project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number}
    return project.element.(dx)
end

# A zero differential becomes an array of zeros of the stored size.
function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number}
    dims = project.size
    return zeros(T, dims)
end

# For a `Tangent` of a `SubArray`, project its `parent` entry.
function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number}
    parent_dx = dx.parent
    return project(parent_dx)
end
86+
87+
# Diagonal
88+
# Only a projector for the diagonal vector is stored.
function ProjectTo(x::T) where {T<:Diagonal}
    return ProjectTo{T}(; diag=ProjectTo(diag(x)))
end

# Take the diagonal of `dx`, project it, and rebuild the `Diagonal`.
function (project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal}
    projected_diag = project.diag(diag(dx))
    return T(projected_diag)
end

# A zero differential is passed to the diagonal's projector and the result re-wrapped.
function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal}
    projected_diag = project.diag(dx)
    return T(projected_diag)
end
91+
92+
# :data, :uplo fields
93+
for Wrapper in (:Symmetric, :Hermitian)
    @eval begin
        # Store the `uplo` of `x` as a `Symbol`, plus a projector for the parent.
        function ProjectTo(x::T) where {T<:$Wrapper}
            return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x)))
        end

        # Matrices: project with the parent's projector, then re-wrap with the
        # stored `uplo`.
        (project::ProjectTo{<:$Wrapper})(dx::AbstractMatrix) =
            $Wrapper(project.parent(dx), project.uplo)

        # Zero differentials are handled the same way.
        (project::ProjectTo{<:$Wrapper})(dx::AbstractZero) =
            $Wrapper(project.parent(dx), project.uplo)

        # `Tangent`s: the `data` entry is what gets projected.
        (project::ProjectTo{<:$Wrapper})(dx::Tangent) =
            $Wrapper(project.parent(dx.data), project.uplo)
    end
end
109+
110+
# :data field
111+
for Tri in (:UpperTriangular, :LowerTriangular)
    @eval begin
        # Only a projector for the parent is stored.
        function ProjectTo(x::T) where {T<:$Tri}
            return ProjectTo{T}(; parent=ProjectTo(parent(x)))
        end

        # Matrices and zero differentials: project with the parent's projector,
        # then re-wrap in the triangular type.
        function (project::ProjectTo{<:$Tri})(dx::AbstractMatrix)
            return $Tri(project.parent(dx))
        end
        function (project::ProjectTo{<:$Tri})(dx::AbstractZero)
            return $Tri(project.parent(dx))
        end

        # `Tangent`s: the `data` entry is what gets projected.
        function (project::ProjectTo{<:$Tri})(dx::Tangent)
            return $Tri(project.parent(dx.data))
        end
    end
end
119+
120+
# Transpose
121+
# Only a projector for the parent is stored.
function ProjectTo(x::T) where {T<:Transpose}
    return ProjectTo{T}(; parent=ProjectTo(parent(x)))
end

# Transpose `dx`, project it with the parent's projector, and transpose back.
function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix)
    projected_parent = project.parent(transpose(dx))
    return transpose(projected_parent)
end

# Zero differentials go to the parent's projector directly; the result is transposed.
function (project::ProjectTo{<:Transpose})(dx::AbstractZero)
    return transpose(project.parent(dx))
end
126+
127+
# Adjoint
128+
# Only a projector for the parent is stored.
function ProjectTo(x::T) where {T<:Adjoint}
    return ProjectTo{T}(; parent=ProjectTo(parent(x)))
end

# Take the adjoint of `dx`, project it with the parent's projector, and take the
# adjoint again.
function (project::ProjectTo{<:Adjoint})(dx::AbstractMatrix)
    projected_parent = project.parent(adjoint(dx))
    return adjoint(projected_parent)
end

# Zero differentials go to the parent's projector directly; the result is adjointed.
function (project::ProjectTo{<:Adjoint})(dx::AbstractZero)
    return adjoint(project.parent(dx))
end
131+
132+
# PermutedDimsArray
133+
# Only a projector for the parent (un-permuted) array is stored.
function ProjectTo(x::P) where {P<:PermutedDimsArray}
    return ProjectTo{P}(; parent=ProjectTo(parent(x)))
end

# Project an array differential onto a `PermutedDimsArray`.
#
# `dx` has the layout of the permuted array, whereas the `PermutedDimsArray`
# constructor used here wraps data in the parent's layout. Going from the permuted
# layout back to the parent's layout requires the inverse permutation `iperm`
# (permuting with `perm` again only works when `perm` is its own inverse, e.g. for a
# 2-d transpose). The permutation is undone before projecting, so that the parent's
# projector receives data in the parent's layout.
function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(
    dx::AbstractArray
) where {T,N,perm,iperm,AA}
    parent_dx = project.parent(permutedims(dx, iperm))
    return PermutedDimsArray{T,N,perm,iperm,AA}(parent_dx)
end

# Zero differentials go to the parent's projector directly; its result is wrapped
# in `P`.
function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray}
    return P(project.parent(dx))
end
142+
143+
# SubArray
144+
# Don't project onto a view; project onto a matching copy instead.
function ProjectTo(x::T) where {T<:SubArray}
    return ProjectTo(copy(x))
end

0 commit comments

Comments
 (0)