Skip to content

Commit cfdb033

Browse files
authored
Fix AD issue with mixed input for TransformedKernel (#160)
* Fix mixed input bug for TransformedKernel * Fix tests * Support mized input in pairwise * avoid splatting * style fixes * use permute dims
1 parent 0e3ed6d commit cfdb033

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ dim(x::ColVecs) = size(x.X, 1)
4343

4444
pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2)
4545
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2)
46+
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
47+
return Distances.pairwise(d, reduce(hcat, x), y.X; dims=2)
48+
end
49+
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
50+
return Distances.pairwise(d, x.X, reduce(hcat, y); dims=2)
51+
end
4652
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
4753
return Distances.pairwise!(out, d, x.X; dims=2)
4854
end
@@ -73,6 +79,12 @@ dim(x::RowVecs) = size(x.X, 2)
7379

7480
pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1)
7581
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1)
82+
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
83+
return Distances.pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
84+
end
85+
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
86+
return Distances.pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
87+
end
7688
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
7789
return Distances.pairwise!(out, d, x.X; dims=1)
7890
end

test/kernels/transformedkernel.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@
5050
tmp_diag = Vector{Float64}(undef, length(x))
5151
@test kerneldiagmatrix!(tmp_diag, kt, x) kerneldiagmatrix(kt, x)
5252
end
53+
54+
@testset "mixed inputs" begin
55+
k = transform(SqExponentialKernel(), 10.0)
56+
x = rand(rng, 5, 3)
57+
X1 = collect(eachcol(x))
58+
Y1 = KernelFunctions.ColVecs(x)
59+
@test_nowarn Zygote.gradient(k-> sum(kernelmatrix(k, X1, Y1)), k)
60+
@test_nowarn Zygote.gradient(k-> sum(kernelmatrix(k, Y1, X1)), k)
61+
@test kernelmatrix(k, X1, Y1) kernelmatrix(k, X1, X1) kernelmatrix(k, Y1, Y1)
62+
63+
X2 = collect(eachrow(x))
64+
Y2 = KernelFunctions.RowVecs(x)
65+
@test_nowarn Zygote.gradient(k-> sum(kernelmatrix(k, X2, Y2)), k)
66+
@test_nowarn Zygote.gradient(k-> sum(kernelmatrix(k, Y2, X2)), k)
67+
@test kernelmatrix(k, X2, Y2) kernelmatrix(k, X2, X2) kernelmatrix(k, Y2, Y2)
68+
end
5369
end
5470
test_ADs(x->transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
5571
# Test implicit gradients

0 commit comments

Comments
 (0)