Merged
Changes from 13 commits
2 changes: 1 addition & 1 deletion src/basekernels/maha.jl
@@ -3,7 +3,7 @@

Mahalanobis distance-based kernel given by
```math
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
```
where the matrix P is the metric.

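For reference, a minimal sketch of the corrected convention (assuming the `MahalanobisKernel` constructor and `sqmahalanobis` from Distances.jl, as exercised in the tests below):

```julia
using KernelFunctions, Distances, LinearAlgebra

# Any positive-definite matrix serves as the metric P.
P = Matrix(2.0I, 3, 3)
k = MahalanobisKernel(P)
x, y = rand(3), rand(3)

# Under the fixed docstring, P itself (not inv(P)) enters the quadratic form:
k(x, y) ≈ exp(-sqmahalanobis(x, y, P))  # true
```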
28 changes: 28 additions & 0 deletions src/basekernels/nn.jl
@@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X, dims=1)
Y_2 = sum(y.X .* y.X, dims=1)
XY = x.X' * y.X
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
X_2_1 = sum(x.X .* x.X, dims=1) .+ 1
XX = x.X' * x.X
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X, dims=2)
Y_2 = sum(y.X .* y.X, dims=2)
XY = x.X * y.X'
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
X_2_1 = sum(x.X .* x.X, dims=2) .+ 1
XX = x.X * x.X'
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
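As a quick sanity check (a sketch, not part of the diff), the vectorized methods above should agree with naive pairwise evaluation of the kernel:

```julia
using KernelFunctions, Test

k = NeuralNetworkKernel()
X = rand(3, 5)        # 3-dimensional features, 5 observations as columns
xs = ColVecs(X)

# The specialized kernelmatrix should match elementwise kernel evaluation.
@test kernelmatrix(k, xs) ≈ [k(x, y) for x in xs, y in xs]
```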
33 changes: 33 additions & 0 deletions src/zygote_adjoints.jl
@@ -84,3 +84,36 @@ end
@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end

@adjoint function (dist::Distances.SqMahalanobis)(a, b)
function back(Δ::Real)
B_Bᵀ = dist.qmat + transpose(dist.qmat)
a_b = a - b
δa = B_Bᵀ * a_b
return (qmat = a_b * a_b',), δa, -δa
end
return evaluate(dist, a, b), back
end


@adjoint function Distances.pairwise(
dist::SqMahalanobis,
a::AbstractMatrix,
b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing
)
function back(Δ::AbstractMatrix)
B_Bᵀ = dist.qmat + transpose(dist.qmat)
a_b = map(
x -> (first(last(x)) - last(last(x)))*first(x),
zip(
Δ,
Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims))
)
)
δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2))
δB = sum(map(x -> x*transpose(x), a_b))
return (qmat=δB,), δa, -δa
Member:

There is some discrepancy between the simple case above and this pullback - intuitively, from the simple case above I would assume that δB = sum_{i,j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}. However, here you compute δB = sum_{i,j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}^2. Probably one of them is incorrect (table 7 in https://notendur.hi.is/jonasson/greinar/blas-rmd.pdf indicates that the pairwise one is incorrect). Can we add the derivation of the adjoints according to https://www.juliadiff.org/ChainRulesCore.jl/dev/arrays.html as docstrings or comments, or maybe even have a separate PR for the Mahalanobis fixes?

Contributor (author):

Thanks for pointing this out. I think a separate PR for mahalanobis fixes makes more sense.

end
return Distances.pairwise(dist, a, b, dims=dims), back
end
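For reference, a sketch of the adjoint derivation requested in the thread above (standard matrix calculus, not taken from the PR). For the scalar case with seed `Δ`,

```math
r^2 = (a - b)^\top B (a - b), \qquad
\bar{a} = \Delta\, (B + B^\top)(a - b), \qquad
\bar{B} = \Delta\, (a - b)(a - b)^\top,
```

and for the pairwise case with seed matrix `Δ`,

```math
\bar{B} = \sum_{i,j} \Delta_{ij}\, (a_i - b_j)(a_i - b_j)^\top, \qquad
\bar{a}_i = (B + B^\top) \sum_j \Delta_{ij}\, (a_i - b_j),
```

in which `Δ` enters linearly, consistent with the scalar case and with the discrepancy noted in the thread.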
3 changes: 1 addition & 2 deletions test/basekernels/exponential.jl
@@ -38,8 +38,7 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
#Coherence :
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
@test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2)
Expand Down
4 changes: 2 additions & 2 deletions test/basekernels/fbm.jl
@@ -22,6 +22,6 @@
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5

@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
test_ADs(FBMKernel, ADs = [:ReverseDiff])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
end
3 changes: 1 addition & 2 deletions test/basekernels/gabor.jl
@@ -17,7 +17,6 @@
@test k.ell ≈ 1.0 atol=1e-5
@test k.p ≈ 1.0 atol=1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Tests failing for Zygote on differentiating through ell and p"
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
# Tests are also failing for ForwardDiff and ReverseDiff, but only intermittently
end
18 changes: 16 additions & 2 deletions test/basekernels/maha.jl
@@ -4,13 +4,27 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)

P = rand(rng, 3, 3)
U = UpperTriangular(rand(rng, 3,3))
P = Matrix(Cholesky(U, 'U', 0))
@assert isposdef(P)
k = MahalanobisKernel(P)

@test kappa(k, x) == exp(-x)
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
# test_ADs(P -> MahalanobisKernel(P), P)

M1, M2 = rand(rng,3,2), rand(rng,3,2)
fdm = FiniteDifferences.Central(5, 1);


FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
Member:

Is this needed? If possible, we should avoid this type piracy.

Contributor (author):

Yes, j′vp only works when there is a to_vec function defined for each argument.

Member:

I'm wondering, since according to the docs to_vec is only needed for the inputs xs... but not for the evaluated function f in j′vp(fdm, f, xs...).

Contributor (author) (@sharanry, Aug 26, 2020):

From what I understand, it is also needed for objects like SqMahalanobis if they have parameters like qmat.

Member:

That's correct, but actually for some reason we've not made FiniteDifferences handle functions-with-data properly yet, so you'll have to build the SqMaha object inside of the function that you're differentiating.


@test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])
Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)
@test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈
Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1])
Contributor (author):

@devmotion I tried doing what you suggested. The tests still fail. This error probably propagates and causes even the first test to fail.

julia> j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1]
3×3 UpperTriangular{Float64,Array{Float64,2}}:
 0.228808    0.00318764   -0.107503
  ⋅         -0.000391803   0.0132135
  ⋅           ⋅            0.0438772

julia> Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]
3×3 Array{Float64,2}:
  0.228808    0.00318764   -0.107503
 -0.0281234  -0.000391803   0.0132135
 -0.0933875  -0.00130103    0.0438772

Member:

To me your output indicates that it basically works apart from the fact that Zygote incorrectly returns a dense matrix instead of an upper triangular matrix. Since U was upper triangular, only the values above and on the diagonal should be returned.

Contributor (author) (@sharanry, Aug 26, 2020):

FiniteDifferences is pretty good at matching the types; Zygote isn't. Do you suggest we manually check whether the upper triangular part matches for now?

Edit: I don't think we are addressing the major issue here. Our goal is to make the overall adjoint correct for kernelmatrix. So maybe defining a custom Zygote adjoint for UpperTriangular which outputs an UpperTriangular might solve the problem.

Member:

If the call to UpperTriangular were inside the function, then the adjoint that you would get from Zygote would also be UpperTriangular. Maybe just do that?


# test_ADs(U -> MahalanobisKernel(Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
end
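A minimal sketch of that last suggestion (hypothetical, not part of the PR), wrapping in `UpperTriangular` inside the differentiated closure:

```julia
using LinearAlgebra, Distances, Zygote

# Hypothetical: apply UpperTriangular inside the closure so the triangular
# structure is part of the differentiated program, and Zygote's adjoint with
# respect to the dense input reflects it. v1, v2, and U as in the test above.
function f(u)
    Ut = UpperTriangular(u)
    return SqMahalanobis(Array(Ut' * Ut))(v1, v2)
end
_, back = Zygote.pullback(f, Matrix(U))
back(1.0)
```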
3 changes: 1 addition & 2 deletions test/basekernels/nn.jl
@@ -43,6 +43,5 @@
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))

@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote uncompatible with BaseKernel"
test_ADs(NeuralNetworkKernel)
end
10 changes: 10 additions & 0 deletions test/zygote_adjoints.jl
@@ -4,6 +4,9 @@
x = rand(rng, 5)
y = rand(rng, 5)
r = rand(rng, 5)
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
@assert isposdef(Q)


gzeucl = gradient(:Zygote, [x,y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
@@ -20,6 +23,9 @@
gzsinus = gradient(:Zygote, [x,y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gzsqmaha = gradient(:Zygote, [Q,x,y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end

gfeucl = gradient(:FiniteDiff, [x,y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
@@ -36,11 +42,15 @@
gfsinus = gradient(:FiniteDiff, [x,y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gfsqmaha = gradient(:FiniteDiff, [Q,x,y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end


@test all(gzeucl .≈ gfeucl)
@test all(gzsqeucl .≈ gfsqeucl)
@test all(gzdotprod .≈ gfdotprod)
@test all(gzdelta .≈ gfdelta)
@test all(gzsinus .≈ gfsinus)
@test all(gzsqmaha .≈ gfsqmaha)
end