Skip to content

Commit 4dfffd4

Browse files
committed
Added adjoint and kmatrix(X,Y)
1 parent 73d68b8 commit 4dfffd4

File tree

5 files changed

+24
-6
lines changed

5 files changed

+24
-6
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ version = "0.1.0"
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export Kernel, SquaredExponentialKernel
66
export Transform, ScaleTransform
77

88
using Distances, LinearAlgebra
9+
using Zygote: @adjoint
910

1011
const defaultobs = 2
1112
abstract type Kernel{T,Tr} end

src/kernelmatrix.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,18 @@ function kernelmatrix(
116116
Y::AbstractMatrix{T₂};
117117
obsdim=defaultobs
118118
) where {T,T₁<:Real,T₂<:Real}
119-
Tₖ = typeof(zero(eltype(X))*zero(eltype(Y))*zero(T))
120-
m = size(X,obsdim)
121-
n = size(Y,obsdim)
122-
kernelmatrix!(Matrix{Tₖ}(undef,m,n),κ,X,Y,obsdim=obsdim)
119+
# Tₖ = typeof(zero(eltype(X))*zero(T))
120+
# m = size(X,obsdim)
121+
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
122+
# K = Matrix{Tₖ}(undef,m,m)
123+
# for i in 1:m
124+
# tx = transform(κ,@view X[i,:])
125+
# for j in 1:i
126+
# K[i,j] = kappa(κ,kernel(κ,tx,transform(@view X[j,:])))
127+
# end
128+
# end
129+
return K
130+
# return kernelmatrix!(Matrix{Tₖ}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
123131
end
124132

125133

src/transform/transform.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,12 @@ end
3232
transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = t.s.*x
3333
transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = obsdim == 1 ? t.s'.*X : t.s.*X
3434

35-
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat) = t.s*x
35+
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int) = t.s*x
36+
37+
@adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = transform(t,x),Δ->.*x,t.s.*Δ)
38+
@adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
39+
@show Δ,size(Δ);
40+
return (obsdim == 1 ? Δ'.*X : Δ.*X,transform(t,Δ,obsdim),nothing)
41+
end
42+
43+
@adjoint transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int) = transform(t,x), Δ->.s.*x,t.s.*Δ)

test/testAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ vl = l*ones(dims[1])
1313
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
1414
testfunction(k,A) = sum(kernelmatrix(k,A))
1515

16-
16+
testfunction(SquaredExponentialKernel(vl),A)
1717
#For debugging
1818

1919
## Zygote

0 commit comments

Comments
 (0)