Skip to content

Commit 0525191

Browse files
committed
wip: still does not work
1 parent 3b929e4 commit 0525191

File tree

2 files changed

+4
-33
lines changed

2 files changed

+4
-33
lines changed

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using AbstractGPs:
1818
TestUtils
1919

2020
using Aqua
21-
using DifferentiationInterface
21+
import DifferentiationInterface as DI
2222
using DifferentiationInterface: gradient, jacobian, value_and_gradient, value_and_jacobian
2323
using Documenter
2424
using Distributions: MvNormal, PDMat, loglikelihood, Distributions

test/test_util.jl

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,38 +55,9 @@ function adjoint_test(
5555
f, ȳ, x...; rtol=_rtol, atol=_atol, fdm=central_fdm(5, 1), print_results=false
5656
)
5757
# Compute forwards-pass and j′vp.
58-
backend = AutoMooncake()
59-
y = f(x...)
60-
# Compute VJP using DifferentiationInterface
61-
# For vector-valued functions, we need to use value_and_jacobian and compute VJP manually
62-
if length(x) == 1
63-
# Single input case
64-
if y isa AbstractVector
65-
# Vector-valued function: compute jacobian and then VJP
66-
val, jac = value_and_jacobian(f, backend, x[1])
67-
adj_ad = (vec(ȳ' * jac),)
68-
else
69-
# Scalar-valued function: use gradient
70-
grad_ad = gradient(f, backend, x[1])
71-
adj_ad = (grad_ad .* ȳ,)
72-
end
73-
else
74-
# Multiple input case - compute jacobian for each input
75-
adj_ad = ntuple(length(x)) do i
76-
f_i(xi) = f(x[1:(i - 1)]..., xi, x[(i + 1):end]...)
77-
y_i = f_i(x[i])
78-
if y_i isa AbstractVector
79-
# Vector-valued function
80-
val, jac = value_and_jacobian(f_i, backend, x[i])
81-
vec(ȳ' * jac)
82-
else
83-
# Scalar-valued function
84-
grad_i = gradient(f_i, backend, x[i])
85-
grad_i .* ȳ
86-
end
87-
end
88-
end
89-
adj_fd = j′vp(fdm, f, ȳ, x...)
58+
_f = (x) -> f(x...)
59+
y, adj_ad = DI.value_and_pullback(_f, AutoMooncake(), x, ȳ)
60+
adj_fd = j′vp(fdm, f, ȳ, x...)
9061
9162
# Check that forwards-pass agrees with plain forwards-pass.
9263
@test y ≈ f(x...)

0 commit comments

Comments
 (0)