@@ -55,38 +55,9 @@ function adjoint_test(
5555 f, ȳ, x...; rtol=_rtol, atol=_atol, fdm=central_fdm(5, 1), print_results=false
5656)
5757 # Compute forwards-pass and j′vp.
58- backend = AutoMooncake()
59- y = f(x...)
60- # Compute VJP using DifferentiationInterface
61- # For vector-valued functions, we need to use value_and_jacobian and compute VJP manually
62- if length(x) == 1
63- # Single input case
64- if y isa AbstractVector
65- # Vector-valued function: compute jacobian and then VJP
66- val, jac = value_and_jacobian(f, backend, x[1])
67- adj_ad = (vec(ȳ' * jac),)
68- else
69- # Scalar-valued function: use gradient
70- grad_ad = gradient(f, backend, x[1 ])
71- adj_ad = (grad_ad .* ȳ,)
72- end
73- else
74- # Multiple input case - compute jacobian for each input
75- adj_ad = ntuple(length(x)) do i
76- f_i(xi) = f(x[1 : (i - 1 )]. .. , xi, x[(i + 1 ): end ]. .. )
77- y_i = f_i(x[i])
78- if y_i isa AbstractVector
79- # Vector-valued function
80- val, jac = value_and_jacobian(f_i, backend, x[i])
81- vec(ȳ' * jac)
82- else
83- # Scalar-valued function
84- grad_i = gradient(f_i, backend, x[i])
85- grad_i .* ȳ
86- end
87- end
88- end
89- adj_fd = j′vp(fdm, f, ȳ, x...)
58+ _f = (x) -> f(x...)
59+ y, adj_ad = DI.value_and_pullback(_f, AutoMooncake(), x, ȳ)
60+ adj_fd = j′vp(fdm, f, ȳ, x...)
9061
9162 # Check that forwards-pass agrees with plain forwards-pass.
9263 @test y ≈ f(x...)
0 commit comments