Skip to content

Commit c1ea0d0

Browse files
committed
Adding tests to check implicit behavior
1 parent 7ec483b commit c1ea0d0

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

test/kernels/transformedkernel.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,27 @@
5252
end
5353
end
5454
test_ADs(x->transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
55+
# Test implicit gradients
56+
@testset "Implicit gradients" begin
57+
k = transform(SqExponentialKernel(), 2.0)
58+
ps = Flux.params(k)
59+
X = rand(10, 1); x = vec(X)
60+
A = rand(10, 10)
61+
# Implicit
62+
g1 = Flux.gradient(ps) do
63+
tr(kernelmatrix(k, X, obsdim = 1) * A)
64+
end
65+
# Explicit
66+
g2 = Flux.gradient(k) do k
67+
tr(kernelmatrix(k, X, obsdim = 1) * A)
68+
end
69+
70+
# Implicit for a vector
71+
g3 = Flux.gradient(ps) do
72+
tr(kernelmatrix(k, x) * A)
73+
end
74+
@test g1[first(ps)] first(g2).transform.s
75+
@test g1[first(ps)] g3[first(ps)]
76+
end
77+
5578
end

0 commit comments

Comments
 (0)