Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,4 @@ kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)

metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)

function dot_perslice(A::AbstractMatrix, B::AbstractMatrix; dims=2)
return reshape(sum(A .* B, dims=3-dims), :)
end

Base.show(io::IO, κ::MahalanobisKernel) = print(io, "Mahalanobis Kernel (size(P) = ", size(κ.P), ")")
18 changes: 16 additions & 2 deletions test/basekernels/maha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,27 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)

P = rand(rng, 3, 3)
U = UpperTriangular(rand(rng, 3,3))
P = Matrix(Cholesky(U, 'U', 0))
@assert isposdef(P)
k = MahalanobisKernel(P)

@test kappa(k, x) == exp(-x)
@test k(v1, v2) exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
test_ADs(P -> MahalanobisKernel(P), P, ADs=[:Zygote])

M1, M2 = rand(rng,3,2), rand(rng,3,2)
fdm = FiniteDifferences.Central(5, 1);


FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? If possible, we should avoid this type piracy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes j′vp only works when there is a to_vec function defined for each argument.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering since according to the docs to_vec is only needed for the inputs xs... but not the evaluated function f in j'vp(fdm, f, xs...).

Copy link
Contributor Author

@sharanry sharanry Aug 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand, it is also needed for objects like SqMahalanobis if they have parameters like qmat.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct, but actually for some reason we've not made FiniteDifferences handle functions-with-data properly yet, so you'll have to build the SqMaha object inside of the function that you're differentiating.


@test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])
Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)
@test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈
Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@devmotion I tried doing what you suggested. The tests still fail. This error probably propagates and causes even the first test to fail.

julia> j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1]
3×3 UpperTriangular{Float64,Array{Float64,2}}:
 0.228808   0.00318764   -0.107503
          -0.000391803   0.0132135
                        0.0438772

julia> Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]
3×3 Array{Float64,2}:
  0.228808    0.00318764   -0.107503
 -0.0281234  -0.000391803   0.0132135
 -0.0933875  -0.00130103    0.0438772

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me your output indicates that it basically works apart from the fact that Zygote incorrectly returns a dense matrix instead of an upper triangular matrix. Since U was upper triangular, only the values above and on the diagonal should be returned.

Copy link
Contributor Author

@sharanry sharanry Aug 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FiniteDifferences if pretty good in matching the types. Zygote isn't. Do you suggest we manually check if the upper triangular part matches for now?

Edit: I don't we are addressing the major issue here. Our goal is to make the overall adjoint correct for kernelmatrix. So maybe defining a custom zygote adjoint for UpperTriangular which outputs a UpperTriangular might solve the problem.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were the call to UpperTriangular inside the function, then the adjoint that you would get from Zygote would also be UpperTriangular. Maybe just do that?


# test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
end