Skip to content

Commit aa2099e

Browse files
theogfdevmotiongithub-actions[bot]
authored
Fix gradient issues with kernelmatrix_diag and use ChainRulesCore (#208)
Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ae78b73 commit aa2099e

26 files changed

+366
-169
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.25"
3+
version = "0.8.26"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
89
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -17,8 +18,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1819

1920
[compat]
21+
ChainRulesCore = "0.9"
2022
Compat = "3.7"
21-
Distances = "0.9.1, 0.10"
23+
Distances = "0.10"
2224
Functors = "0.1"
2325
Requires = "1.0.1"
2426
SpecialFunctions = "0.8, 0.9, 0.10, 1"

src/KernelFunctions.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ export IndependentMOKernel, LatentFactorMOKernel
5555
export tensor,
5656

5757
using Compat
58+
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
59+
using ChainRulesCore: @thunk, InplaceableThunk
5860
using Requires
5961
using Distances, LinearAlgebra
6062
using Functors
6163
using SpecialFunctions: loggamma, besselk, polygamma
62-
using ZygoteRules: @adjoint, pullback
63-
using StatsFuns: logtwo
64+
using ZygoteRules: ZygoteRules
65+
using StatsFuns: logtwo, twoπ
6466
using StatsBase
6567
using TensorCore
6668

@@ -112,7 +114,8 @@ include(joinpath("mokernels", "moinput.jl"))
112114
include(joinpath("mokernels", "independent.jl"))
113115
include(joinpath("mokernels", "slfm.jl"))
114116

115-
include("zygote_adjoints.jl")
117+
include("chainrules.jl")
118+
include("zygoterules.jl")
116119

117120
include("test_utils.jl")
118121

src/basekernels/fbm.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,14 @@ function kernelmatrix!(
6969
K .= _fbm.(_mod(x), _mod(y)', K, κ.h)
7070
return K
7171
end
72+
73+
function kernelmatrix_diag::FBMKernel, x::AbstractVector)
74+
modx = _mod(x)
75+
modxx = colwise(SqEuclidean(), x)
76+
return _fbm.(modx, modx, modxx, κ.h)
77+
end
78+
79+
function kernelmatrix_diag::FBMKernel, x::AbstractVector, y::AbstractVector)
80+
modxy = colwise(SqEuclidean(), x, y)
81+
return _fbm.(_mod(x), _mod(y), modxy, κ.h)
82+
end

src/basekernels/gabor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,7 @@ function kernelmatrix(κ::GaborKernel, x::AbstractVector, y::AbstractVector)
7272
end
7373

7474
kernelmatrix_diag::GaborKernel, x::AbstractVector) = kernelmatrix_diag.kernel, x)
75+
76+
function kernelmatrix_diag::GaborKernel, x::AbstractVector, y::AbstractVector)
77+
return kernelmatrix_diag.kernel, x, y)
78+
end

src/basekernels/nn.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
5151
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
5252
end
5353

54+
function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs)
55+
x_2 = vec(sum(x.X .* x.X; dims=1))
56+
return asin.(x_2 ./ (x_2 .+ 1))
57+
end
58+
59+
function kernelmatrix_diag(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
60+
validate_inputs(x, y)
61+
x_2 = vec(sum(x.X .* x.X; dims=1) .+ 1)
62+
y_2 = vec(sum(y.X .* y.X; dims=1) .+ 1)
63+
xy = vec(sum(x.X' .* y.X'; dims=2))
64+
return asin.(xy ./ sqrt.(x_2 .* y_2))
65+
end
66+
5467
function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
5568
validate_inputs(x, y)
5669
X_2 = sum(x.X .* x.X; dims=2)
@@ -65,4 +78,17 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
6578
return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
6679
end
6780

81+
function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs)
82+
x_2 = vec(sum(x.X .* x.X; dims=2))
83+
return asin.(x_2 ./ (x_2 .+ 1))
84+
end
85+
86+
function kernelmatrix_diag(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
87+
validate_inputs(x, y)
88+
x_2 = vec(sum(x.X .* x.X; dims=2) .+ 1)
89+
y_2 = vec(sum(y.X .* y.X; dims=2) .+ 1)
90+
xy = vec(sum(x.X .* y.X; dims=2))
91+
return asin.(xy ./ sqrt.(x_2 .* y_2))
92+
end
93+
6894
Base.show(io::IO, ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")

src/chainrules.jl

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
## Forward Rules
2+
3+
# Note that this is type piracy as the derivative should be NaN for x == y.
4+
function ChainRulesCore.frule(
5+
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector
6+
)
7+
Δ = x - y
8+
D = sqrt(sum(abs2, Δ))
9+
if !iszero(D)
10+
Δ ./= D
11+
end
12+
return D, dot(Δ, Δx) - dot(Δ, Δy)
13+
end
14+
15+
## Reverse Rules Delta
16+
17+
function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector)
18+
d = dist(x, y)
19+
function evaluate_pullback(::Any)
20+
return NO_FIELDS, Zero(), Zero()
21+
end
22+
return d, evaluate_pullback
23+
end
24+
25+
function ChainRulesCore.rrule(
26+
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
27+
)
28+
P = Distances.pairwise(d, X, Y; dims=dims)
29+
function pairwise_pullback(::AbstractMatrix)
30+
return NO_FIELDS, NO_FIELDS, Zero(), Zero()
31+
end
32+
return P, pairwise_pullback
33+
end
34+
35+
function ChainRulesCore.rrule(
36+
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
37+
)
38+
P = Distances.pairwise(d, X; dims=dims)
39+
function pairwise_pullback(::AbstractMatrix)
40+
return NO_FIELDS, NO_FIELDS, Zero()
41+
end
42+
return P, pairwise_pullback
43+
end
44+
45+
function ChainRulesCore.rrule(
46+
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
47+
)
48+
C = Distances.colwise(d, X, Y)
49+
function colwise_pullback(::AbstractVector)
50+
return NO_FIELDS, NO_FIELDS, Zero(), Zero()
51+
end
52+
return C, colwise_pullback
53+
end
54+
55+
## Reverse Rules DotProduct
56+
57+
function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector)
58+
d = dist(x, y)
59+
function evaluate_pullback::Any)
60+
return NO_FIELDS, Δ .* y, Δ .* x
61+
end
62+
return d, evaluate_pullback
63+
end
64+
65+
function ChainRulesCore.rrule(
66+
::typeof(Distances.pairwise),
67+
d::DotProduct,
68+
X::AbstractMatrix,
69+
Y::AbstractMatrix;
70+
dims=2,
71+
)
72+
P = Distances.pairwise(d, X, Y; dims=dims)
73+
function pairwise_pullback_cols::AbstractMatrix)
74+
if dims == 1
75+
return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X
76+
else
77+
return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ
78+
end
79+
end
80+
return P, pairwise_pullback_cols
81+
end
82+
83+
function ChainRulesCore.rrule(
84+
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
85+
)
86+
P = Distances.pairwise(d, X; dims=dims)
87+
function pairwise_pullback_cols::AbstractMatrix)
88+
if dims == 1
89+
return NO_FIELDS, NO_FIELDS, 2 * Δ * X
90+
else
91+
return NO_FIELDS, NO_FIELDS, 2 * X * Δ
92+
end
93+
end
94+
return P, pairwise_pullback_cols
95+
end
96+
97+
function ChainRulesCore.rrule(
98+
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
99+
)
100+
C = Distances.colwise(d, X, Y)
101+
function colwise_pullback::AbstractVector)
102+
return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X
103+
end
104+
return C, colwise_pullback
105+
end
106+
107+
## Reverse Rules Sinus
108+
109+
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
110+
d = x - y
111+
sind = sinpi.(d)
112+
abs2_sind_r = abs2.(sind) ./ s.r
113+
val = sum(abs2_sind_r)
114+
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
115+
function evaluate_pullback::Any)
116+
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx
117+
end
118+
return val, evaluate_pullback
119+
end
120+
121+
## Reverse Rulse SqMahalanobis
122+
123+
function ChainRulesCore.rrule(
124+
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
125+
)
126+
d = dist(a, b)
127+
function SqMahalanobis_pullback::Real)
128+
a_b = a - b
129+
∂qmat = InplaceableThunk(
130+
@thunk((a_b * a_b') * Δ), X̄ -> mul!(X̄, a_b, a_b', true, Δ)
131+
)
132+
∂a = InplaceableThunk(
133+
@thunk((2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ)
134+
)
135+
∂b = InplaceableThunk(
136+
@thunk((-2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ)
137+
)
138+
return Composite{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
139+
end
140+
return d, SqMahalanobis_pullback
141+
end
142+
143+
## Reverse Rules for matrix wrappers
144+
145+
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
146+
ColVecs_pullback::Composite) = (NO_FIELDS, Δ.X)
147+
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
148+
return error(
149+
"Pullback on AbstractVector{<:AbstractVector}.\n" *
150+
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
151+
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
152+
)
153+
end
154+
return ColVecs(X), ColVecs_pullback
155+
end
156+
157+
function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
158+
RowVecs_pullback::Composite) = (NO_FIELDS, Δ.X)
159+
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
160+
return error(
161+
"Pullback on AbstractVector{<:AbstractVector}.\n" *
162+
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
163+
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
164+
)
165+
end
166+
return RowVecs(X), RowVecs_pullback
167+
end

src/distances/delta.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
struct Delta <: Distances.PreMetric end
1+
# Delta is not following the PreMetric rules since d(x, x) == 1
2+
struct Delta <: Distances.UnionPreMetric end
23

34
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector)
45
@boundscheck if length(a) != length(b)
@@ -12,7 +13,7 @@ struct Delta <: Distances.PreMetric end
1213
return a == b
1314
end
1415

15-
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
16+
Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
1617

1718
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
1819
@inline (dist::Delta)(a::Number, b::Number) = a == b

src/distances/dotproduct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
struct DotProduct <: Distances.PreMetric end
2-
# struct DotProduct <: Distances.UnionSemiMetric end
1+
## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y
2+
struct DotProduct <: Distances.UnionPreMetric end
33

44
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
55
@boundscheck if length(a) != length(b)

src/distances/pairwise.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,34 @@ function pairwise!(
2929
)
3030
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
3131
end
32+
33+
# Also defines the colwise method for abstractvectors
34+
35+
function colwise(d::PreMetric, x::AbstractVector)
36+
return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition
37+
end
38+
39+
## The following is a hack for DotProduct and Delta to still work
40+
function colwise(d::Distances.UnionPreMetric, x::ColVecs)
41+
return Distances.colwise(d, x.X, x.X)
42+
end
43+
44+
function colwise(d::Distances.UnionPreMetric, x::RowVecs)
45+
return Distances.colwise(d, x.X', x.X')
46+
end
47+
48+
function colwise(d::Distances.UnionPreMetric, x::AbstractVector)
49+
return map(d, x, x)
50+
end
51+
52+
function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
53+
return Distances.colwise(d, x.X, y.X)
54+
end
55+
56+
function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
57+
return Distances.colwise(d, x.X', y.X')
58+
end
59+
60+
function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector)
61+
return map(d, x, y)
62+
end

src/distances/sinus.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
struct Sinus{T} <: Distances.SemiMetric
2-
# struct Sinus{T} <: Distances.UnionSemiMetric
1+
struct Sinus{T} <: Distances.UnionSemiMetric
32
r::Vector{T}
43
end
54

0 commit comments

Comments
 (0)