-
Notifications
You must be signed in to change notification settings - Fork 40
Fix gradient issues with kernelmatrix_diag and use ChainRulesCore #208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
e525614
e56492a
35a6306
25e5efd
2f85ebc
cae225f
3f16f07
8c0d0a2
13a10fd
78a2078
5ca94e7
2c60abd
9214211
87edbc8
f65556b
48e2dcb
6cc803d
0e30941
61869b1
06bd4f0
8e1e516
aaa16de
52b1ae5
4067a42
641ebee
13d1e39
4675c2f
0b97c1a
ad9838e
650dc08
1703db1
e2cd167
01ffac0
9bfb6eb
f3fa4bc
a0c2a64
ff5a66b
8157b4c
e6bfdb1
db5e7b8
a44a762
72889dd
25549c1
bbe5c7c
e08dbf4
48bd681
3298d34
0b99771
c26edf3
647862a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
## Reverse Rules Delta | ||
|
||
function rrule(::typeof(Distances.evaluate), s::Delta, x::AbstractVector, y::AbstractVector) | ||
d = evaluate(s, x, y) | ||
function evaluate_pullback(::Any) | ||
return NO_FIELDS, Zero(), Zero() | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function rrule( | ||
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
function pairwise_pullback(::Any) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return NO_FIELDS, Zero(), Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
function pairwise_pullback(::Any) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return NO_FIELDS, Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function rrule(::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(::AbstractVector) | ||
return NO_FIELDS, Zero(), Zero() | ||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules DotProduct | ||
function rrule( | ||
::typeof(Distances.evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector | ||
) | ||
d = dot(x, y) | ||
function evaluate_pullback(Δ) | ||
return NO_FIELDS, Δ .* y, Δ .* x | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function rrule( | ||
::typeof(Distances.pairwise), | ||
d::DotProduct, | ||
X::AbstractMatrix, | ||
Y::AbstractMatrix; | ||
dims=2, | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
if dims == 1 | ||
|
||
function pairwise_pullback_cols(Δ) | ||
return NO_FIELDS, Δ * Y, Δ' * X | ||
end | ||
return P, pairwise_pullback_cols | ||
else | ||
function pairwise_pullback_rows(Δ) | ||
return NO_FIELDS, Y * Δ', X * Δ | ||
end | ||
return P, pairwise_pullback_rows | ||
end | ||
end | ||
|
||
function rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
if dims == 1 | ||
function pairwise_pullback_cols(Δ) | ||
return NO_FIELDS, 2 * Δ * X | ||
end | ||
return P, pairwise_pullback_cols | ||
else | ||
function pairwise_pullback_rows(Δ) | ||
return NO_FIELDS, 2 * X * Δ | ||
end | ||
return P, pairwise_pullback_rows | ||
end | ||
end | ||
|
||
function rrule( | ||
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix | ||
) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(Δ::AbstractVector) | ||
return (nothing, Δ' .* Y, Δ' .* X) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules Sinus | ||
function rrule(::typeof(Distances.evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) | ||
d = (x - y) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
sind = sinpi.(d) | ||
val = sum(abs2, sind ./ s.r) | ||
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
function evaluate_pullback(Δ) | ||
return (r=-2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, -Δ * gradx | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
end | ||
return val, evaluate_pullback | ||
end | ||
|
||
## Reverse Rules for matrix wrappers | ||
|
||
function rrule(::ColVecs, X::AbstractMatrix) | ||
ColVecs_pullback(Δ::NamedTuple) = (Δ.X,) | ||
ColVecs_pullback(Δ::AbstractMatrix) = (Δ,) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
return throw(error("In slow method")) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
end | ||
return ColVecs(X), ColVecs_pullback | ||
end | ||
|
||
function rrule(::RowVecs, X::AbstractMatrix) | ||
RowVecs_pullback(Δ::NamedTuple) = (Δ.X,) | ||
RowVecs_pullback(Δ::AbstractMatrix) = (Δ,) | ||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
return throw(error("In slow method")) | ||
end | ||
return RowVecs(X), RowVecs_pullback | ||
end | ||
|
||
theogf marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# function rrule(::typeof(Base.map), t::Transform, X::ColVecs) | ||
# return pullback(_map, t, X) | ||
# end | ||
|
||
# function rrule(::typeof(Base.map), t::Transform, X::RowVecs) | ||
# return pullback(_map, t, X) | ||
# end | ||
|
||
# @adjoint function (dist::Distances.SqMahalanobis)(a, b) | ||
# function SqMahalanobis_pullback(Δ::Real) | ||
# B_Bᵀ = dist.qmat + transpose(dist.qmat) | ||
# a_b = a - b | ||
# δa = (B_Bᵀ * a_b) * Δ | ||
# return (qmat=(a_b * a_b') * Δ,), δa, -δa | ||
# end | ||
# return evaluate(dist, a, b), SqMahalanobis_pullback | ||
# end |
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.