-
Notifications
You must be signed in to change notification settings - Fork 40
[WIP] Fix AD issues with various kernels #154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a6211d0
8704f18
8f44c51
14db1f4
90c1dff
dcf1f6b
16e8af6
ede5879
e8b76ec
e236aaf
d50c73f
090cc8a
45c14d6
b920c19
2630adc
31730a8
e81cb01
4c2f233
0023292
acdec1a
f467162
651ae02
6b114d2
8655911
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,27 @@ | |
v1 = rand(rng, 3) | ||
v2 = rand(rng, 3) | ||
|
||
P = rand(rng, 3, 3) | ||
U = UpperTriangular(rand(rng, 3,3)) | ||
P = Matrix(Cholesky(U, 'U', 0)) | ||
@assert isposdef(P) | ||
k = MahalanobisKernel(P) | ||
|
||
@test kappa(k, x) == exp(-x) | ||
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) | ||
@test kappa(ExponentialKernel(), x) == kappa(k, x) | ||
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))" | ||
test_ADs(P -> MahalanobisKernel(P), P, ADs=[:Zygote]) | ||
|
||
M1, M2 = rand(rng,3,2), rand(rng,3,2) | ||
fdm = FiniteDifferences.Central(5, 1); | ||
|
||
|
||
FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) | ||
|
||
@test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2]) ≈ | ||
sharanry marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1) | ||
@test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈ | ||
Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]) | ||
|
||
|
||
# test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote]) | ||
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed? If possible, we should avoid this type piracy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes
j′vp
only works when there is ato_vec
function defined for each argument.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering since according to the docs
to_vec
is only needed for the inputsxs...
but not the evaluated functionf
inj'vp(fdm, f, xs...)
.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I understand, it is also needed for objects like
SqMahalanobis
if they have parameters likeqmat
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct, but actually for some reason we've not made
FiniteDifferences
handle functions-with-data properly yet, so you'll have to build theSqMaha
object inside of the function that you're differentiating.