-
Notifications
You must be signed in to change notification settings - Fork 40
[WIP] Fix AD issues with various kernels #154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
a6211d0
8704f18
8f44c51
14db1f4
90c1dff
dcf1f6b
16e8af6
ede5879
e8b76ec
e236aaf
d50c73f
090cc8a
45c14d6
b920c19
2630adc
31730a8
e81cb01
4c2f233
0023292
acdec1a
f467162
651ae02
6b114d2
8655911
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,19 +62,19 @@ end | |
@adjoint function ColVecs(X::AbstractMatrix) | ||
back(Δ::NamedTuple) = (Δ.X,) | ||
back(Δ::AbstractMatrix) = (Δ,) | ||
function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
throw(error("In slow method")) | ||
end | ||
return ColVecs(X), back | ||
return ColVecs(X), ColVecs_pullback | ||
end | ||
|
||
@adjoint function RowVecs(X::AbstractMatrix) | ||
back(Δ::NamedTuple) = (Δ.X,) | ||
back(Δ::AbstractMatrix) = (Δ,) | ||
function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) | ||
throw(error("In slow method")) | ||
end | ||
return RowVecs(X), back | ||
return RowVecs(X), RowVecs_pullback | ||
end | ||
|
||
@adjoint function Base.map(t::Transform, X::ColVecs) | ||
|
@@ -84,3 +84,58 @@ end | |
@adjoint function Base.map(t::Transform, X::RowVecs) | ||
pullback(_map, t, X) | ||
end | ||
|
||
@adjoint function (dist::Distances.SqMahalanobis)(a, b) | ||
function SqMahalanobis_pullback(Δ::Real) | ||
B_Bᵀ = dist.qmat + transpose(dist.qmat) | ||
a_b = a - b | ||
δa = (B_Bᵀ * a_b) * Δ | ||
return (qmat = (a_b * a_b') * Δ,), δa, -δa | ||
end | ||
return evaluate(dist, a, b), SqMahalanobis_pullback | ||
sharanry marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
end | ||
|
||
|
||
@adjoint function Distances.pairwise( | ||
dist::SqMahalanobis, | ||
a::AbstractMatrix, | ||
b::AbstractMatrix; | ||
dims::Union{Nothing,Integer}=nothing | ||
) | ||
function pairwise_pullback(Δ::AbstractMatrix) | ||
B_Bᵀ = dist.qmat + transpose(dist.qmat) | ||
a_b = map( | ||
x -> (first(last(x)) - last(last(x)))*first(x), | ||
zip( | ||
Δ, | ||
Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) | ||
) | ||
) | ||
δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2)) | ||
δB = sum(map(x -> x*transpose(x), a_b)) | ||
return (qmat=δB,), δa, -δa | ||
|
||
end | ||
return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback | ||
end | ||
|
||
@adjoint function Distances.pairwise( | ||
dist::SqMahalanobis, | ||
a::AbstractMatrix; | ||
dims::Union{Nothing,Integer}=nothing | ||
) | ||
function pairwise_pullback(Δ::AbstractMatrix) | ||
B_Bᵀ = dist.qmat + transpose(dist.qmat) | ||
a_a = map( | ||
x -> (first(last(x)) - last(last(x)))*first(x), | ||
zip( | ||
Δ, | ||
Iterators.product(eachslice(a, dims=dims), eachslice(a, dims=dims)) | ||
) | ||
) | ||
δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_a), dims=2)) | ||
δB = sum(map(x -> x*transpose(x), a_a)) | ||
return (qmat=δB,), δa | ||
end | ||
return Distances.pairwise(dist, a, b, dims=dims), pairwise_pullback | ||
end | ||
|
Uh oh!
There was an error while loading. Please reload this page.