Draft
24 commits
f3afdec
Initial plan
Copilot Aug 28, 2025
2e3ac84
Replace Zygote with DifferentiationInterface + Mooncake in tests and …
Copilot Aug 28, 2025
9b4d11c
Fix test utilities and examples to work correctly with Differentiatio…
Copilot Aug 28, 2025
96a0bd9
Apply suggestions from code review
yebai Aug 28, 2025
d2e5b9d
Update test/test_util.jl
yebai Aug 28, 2025
d41c9a1
Complete and validate DifferentiationInterface + Mooncake migration
Copilot Aug 28, 2025
400b859
Merge branch 'main' into copilot/fix-427
yebai Aug 28, 2025
f762a6f
Remove all commented-out code from test/finite_gp_projection.jl
Copilot Aug 28, 2025
7b6fc4d
Update test/finite_gp_projection.jl
yebai Aug 28, 2025
cb57b09
Use value_and_gradient for efficiency and update dependency versions
Copilot Aug 28, 2025
20adc06
Fix adjoint_test to use value_and_jacobian for vector-valued functions
Copilot Aug 28, 2025
305b761
Merge branch 'main' into copilot/fix-427
yebai Aug 29, 2025
280354d
Update test/test_util.jl
yebai Aug 29, 2025
3b929e4
fix example mauna loa
yebai Aug 29, 2025
0525191
wip: still does not work
yebai Aug 29, 2025
503dd3d
fix Parametric Heteroscedastic Model
yebai Aug 29, 2025
91e99d7
fix deep kernel learning example
yebai Aug 29, 2025
5ede796
fix more tests.
yebai Aug 29, 2025
33db1e8
Update Project.toml
yebai Aug 29, 2025
861594a
Update examples/3-parametric-heteroscedastic/script.jl
yebai Aug 29, 2025
833dcdf
Fix AutoMooncake instantiation in test_util.jl
yebai Aug 29, 2025
8ccd32b
Import DifferentiationInterface and update training step
yebai Aug 29, 2025
30af1f9
Add DifferentiationInterface to runtests.jl
yebai Aug 29, 2025
121129e
Update Project.toml
yebai Aug 29, 2025
4 changes: 2 additions & 2 deletions examples/1-mauna-loa/Project.toml
@@ -2,11 +2,12 @@
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Review comment: Missing compat bounds for DI and Mooncake

AbstractGPs = "0.5"
@@ -16,4 +17,3 @@ Literate = "2"
Optim = "1"
ParameterHandling = "0.4, 0.5"
Plots = "1"
Zygote = "0.6, 0.7"
12 changes: 7 additions & 5 deletions examples/1-mauna-loa/script.jl
@@ -12,7 +12,8 @@ using CSV, DataFrames # data loading
using AbstractGPs # exact GP regression
using ParameterHandling # for nested and constrained parameters
using Optim # optimization
using Zygote # auto-diff gradient computation
import DifferentiationInterface as DI # auto-diff interface
using Mooncake # AD backend
using Plots # visualisation

# Let's load and visualize the dataset.
@@ -225,14 +226,15 @@ function optimize_loss(loss, θ_init; optimizer=default_optimizer, maxiter=1_000
loss_packed = loss ∘ unflatten

## https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations
## TODO: enable `prep = DI.prepare_gradient(f, backend, x)`
Review comment: This is really important for performance

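A sketch of how the preparation could be wired in (not part of this PR): `flat_θ_init` is a hypothetical name for the flattened initial parameters, and depending on the ADTypes version the `config` keyword of `AutoMooncake` may be optional.

    backend = DI.AutoMooncake(; config=nothing)
    prep = DI.prepare_gradient(loss_packed, backend, flat_θ_init)
    ## inside `fg!`, reuse the preparation on every call:
    val, grad = DI.value_and_gradient(loss_packed, prep, backend, x)
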
function fg!(F, G, x)
if F !== nothing && G !== nothing
val, grad = Zygote.withgradient(loss_packed, x)
G .= only(grad)
val, grad = DI.value_and_gradient(loss_packed, AutoMooncake(), x)
G .= grad
Review comment on lines +232 to +233: why not use DI.value_and_gradient!?

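A possible in-place variant the comment alludes to, writing straight into Optim's buffer `G` (a sketch, not the PR's code):

    val, _ = DI.value_and_gradient!(loss_packed, G, DI.AutoMooncake(; config=nothing), x)
    return val
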
return val
elseif G !== nothing
grad = Zygote.gradient(loss_packed, x)
G .= only(grad)
grad = DI.gradient(loss_packed, AutoMooncake(), x)
G .= grad
Review comment on lines +236 to +237: why not use DI.gradient!?

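Likewise, a sketch of the gradient-only branch with the in-place operator:

    DI.gradient!(loss_packed, G, DI.AutoMooncake(; config=nothing), x)
    return nothing
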
return nothing
elseif F !== nothing
return loss_packed(x)
4 changes: 2 additions & 2 deletions examples/2-deep-kernel-learning/Project.toml
@@ -1,15 +1,16 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Review comment: Missing compat bounds for DI and Mooncake

AbstractGPs = "0.3,0.4,0.5"
@@ -20,5 +21,4 @@ Lux = "1"
MLDataUtils = "0.5"
Optimisers = "0.4"
Plots = "1"
Zygote = "0.7"
julia = "1.10"
5 changes: 3 additions & 2 deletions examples/2-deep-kernel-learning/script.jl
@@ -23,7 +23,8 @@ using Lux
using Optimisers
using Plots
using Random
using Zygote
using Mooncake
import DifferentiationInterface as DI
default(; legendfontsize=15.0, linewidth=3.0);

Random.seed!(42) # for reproducibility
@@ -91,7 +92,7 @@ anim = Animation()
let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
for i in 1:nmax
_, loss_val, _, tstate = Training.single_train_step!(
AutoZygote(), update_kernel_and_loss, (), tstate
DI.AutoMooncake(), update_kernel_and_loss, (), tstate
)

if i % 10 == 0
4 changes: 2 additions & 2 deletions examples/3-parametric-heteroscedastic/Project.toml
@@ -2,13 +2,14 @@
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
AbstractGPsMakie = "7834405d-1089-4985-bd30-732a30b92057"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Review comment: Missing compat bounds for DI and Mooncake

AbstractGPs = "0.5"
@@ -18,4 +19,3 @@ KernelFunctions = "0.10"
Literate = "2"
Optim = "1"
ParameterHandling = "0.4, 0.5"
Zygote = "0.6, 0.7"
12 changes: 6 additions & 6 deletions examples/3-parametric-heteroscedastic/script.jl
@@ -11,10 +11,11 @@
using AbstractGPs
using AbstractGPsMakie
using CairoMakie
import DifferentiationInterface as DI
using KernelFunctions
using Mooncake
using Optim
using ParameterHandling
using Zygote

using LinearAlgebra
using Random
@@ -47,15 +48,14 @@ end;

# We use L-BFGS for optimising the objective function.
# It is a first-order method and hence requires computing the gradient of the objective function.
# We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation with Zygote.
# When computing gradients with Zygote, the objective function is evaluated as well.
# We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation with DifferentiationInterface + Mooncake.
# When computing gradients, the objective function is evaluated as well.
# We can exploit this and [avoid re-evaluating the objective function](https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations) in such cases.
function objective_and_gradient(F, G, flat_θ)
Review comment: TODO: enable preparation with DI.prepare_gradient

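A sketch of the prepared form the TODO asks for (not part of this PR); `flat_θ_init` is a hypothetical name for the flattened initial parameters used to build the preparation once, outside the objective:

    backend = DI.AutoMooncake(; config=nothing)
    prep = DI.prepare_gradient(objective, backend, flat_θ_init)
    ## then, inside `objective_and_gradient`:
    val, _ = DI.value_and_gradient!(objective, G, prep, backend, flat_θ)
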
if G !== nothing
val_grad = Zygote.withgradient(objective, flat_θ)
copyto!(G, only(val_grad.grad))
val, grad = DI.value_and_gradient!(objective, G, DI.AutoMooncake(), flat_θ)
if F !== nothing
return val_grad.val
return val
end
end
if F !== nothing
6 changes: 4 additions & 2 deletions test/Project.toml
@@ -1,30 +1,32 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
DifferentiationInterface = "0.7"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Documenter = "1"
FillArrays = "0.11, 0.12, 0.13, 1"
FiniteDifferences = "0.9.6, 0.10, 0.11, 0.12"
LinearAlgebra = "1"
Mooncake = "0.4"
PDMats = "0.11"
Pkg = "1"
Plots = "1"
Random = "1"
Statistics = "1"
Test = "1"
Zygote = "0.5, 0.6, 0.7"
julia = "1.6"
129 changes: 0 additions & 129 deletions test/finite_gp_projection.jl
@@ -151,13 +151,9 @@ end

# Check gradient of logpdf at mean is zero for `f`.
adjoint_test(ŷ -> logpdf(fx, ŷ), 1, ones(size(ŷ)))
lp, back = Zygote.pullback(ŷ -> logpdf(fx, ŷ), ones(size(ŷ)))
@test back(randn(rng))[1] == zeros(size(ŷ))
Review comment on lines -154 to -155: Why were these tests removed?


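For reference, a hedged DI-based re-expression of the removed check (not part of this PR; the exact-equality assertion may need to become `≈` depending on the backend):

    lp, (grad_ŷ,) = DI.value_and_pullback(ŷ -> logpdf(fx, ŷ), DI.AutoMooncake(; config=nothing), ones(size(ŷ)), (randn(rng),))
    @test grad_ŷ == zeros(size(ŷ))
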
# Check that gradient of logpdf at mean is zero for `y`.
adjoint_test(ŷ -> logpdf(y, ŷ), 1, ones(size(ŷ)))
lp, back = Zygote.pullback(ŷ -> logpdf(y, ŷ), ones(size(ŷ)))
@test back(randn(rng))[1] == zeros(size(ŷ))

# Check that gradient w.r.t. inputs is approximately correct for `f`.
x, l̄ = randn(rng, N), randn(rng)
@@ -211,128 +207,3 @@ end
docstring = string(Docs.doc(logpdf, Tuple{AbstractGPs.FiniteGP,Vector{Float64}}))
@test occursin("logpdf(f::FiniteGP, y::AbstractVecOrMat{<:Real})", docstring)
end

# """
# simple_gp_tests(rng::AbstractRNG, f::GP, xs::AV{<:AV}, σs::AV{<:Real})

# Integration tests for simple GPs.
# """
# function simple_gp_tests(
# rng::AbstractRNG,
# f::GP,
# xs::AV{<:AV},
# isp_σs::AV{<:Real};
# atol=1e-8,
# rtol=1e-8,
# )
# for x in xs, isp_σ in isp_σs

# # Test gradient w.r.t. random sampling.
# N = length(x)
# adjoint_test(
# (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2)),
# randn(rng, N),
# x,
# isp_σ,;
# atol=atol, rtol=rtol,
# )
# adjoint_test(
# (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2), 11),
# randn(rng, N, 11),
# x,
# isp_σ,;
# atol=atol, rtol=rtol,
# )

# # Check that gradient w.r.t. logpdf is correct.
# y, l̄ = rand(rng, f(x, exp(isp_σ))), randn(rng)
# adjoint_test(
# (x, isp_σ, y)->logpdf(f(x, exp(isp_σ)), y),
# l̄, x, isp_σ, y;
# atol=atol, rtol=rtol,
# )

# # Check that elbo is tight-ish when it's meant to be.
# fx, yx = f(x, 1e-9), f(x, exp(isp_σ))
# @test isapprox(elbo(yx, y, fx), logpdf(yx, y); atol=1e-6, rtol=1e-6)

# # Check that gradient w.r.t. elbo is correct.
# adjoint_test(
# (x, ŷ, isp_σ)->elbo(f(x, exp(isp_σ)), ŷ, f(x, 1e-9)),
# randn(rng), x, y, isp_σ;
# atol=1e-6, rtol=1e-6,
# )
# end
# end

# __foo(x) = isnothing(x) ? "nothing" : x

# @testset "FiniteGP (integration)" begin
# rng = MersenneTwister(123456)
# xs = [collect(range(-3.0, stop=3.0, length=N)) for N in [2, 5, 10]]
# σs = log.([1e-1, 1e0, 1e1])
# for (k, name, atol, rtol) in vcat(
# [
# (EQ(), "EQ", 1e-6, 1e-6),
# (Linear(), "Linear", 1e-6, 1e-6),
# (PerEQ(), "PerEQ", 5e-5, 1e-8),
# (Exp(), "Exp", 1e-6, 1e-6),
# ],
# [(
# k(α=α, β=β, l=l),
# "$k_name(α=$(__foo(α)), β=$(__foo(β)), l=$(__foo(l)))",
# 1e-6,
# 1e-6,
# )
# for (k, k_name) in ((EQ, "EQ"), (Linear, "linear"), (Matern12, "exp"))
# for α in (nothing, randn(rng))
# for β in (nothing, exp(randn(rng)))
# for l in (nothing, randn(rng))
# ],
# )
# @testset "$name" begin
# simple_gp_tests(_rng(), GP(k, GPC()), xs, σs; atol=atol, rtol=rtol)
# end
# end
# end

# @testset "FiniteGP (BlockDiagonal obs noise)" begin
# rng, Ns = MersenneTwister(123456), [4, 5]
# x = collect(range(-5.0, 5.0; length=sum(Ns)))
# As = [randn(rng, N, N) for N in Ns]
# Ss = [A' * A + I for A in As]

# S = block_diagonal(Ss)
# Smat = Matrix(S)

# f = GP(cos, EQ(), GPC())
# y = rand(f(x, S))

# @test logpdf(f(x, S), y) ≈ logpdf(f(x, Smat), y)
# adjoint_test(
# (x, S, y)->logpdf(f(x, S), y), randn(rng), x, Smat, y;
# atol=1e-6, rtol=1e-6,
# )
# adjoint_test(
# (x, A1, A2, y)->logpdf(f(x, block_diagonal([A1 * A1' + I, A2 * A2' + I])), y),
# randn(rng), x, As[1], As[2], y;
# atol=1e-6, rtol=1e-6
# )

# @test elbo(f(x, Smat), y, f(x)) ≈ logpdf(f(x, Smat), y)
# @test elbo(f(x, S), y, f(x)) ≈
# elbo(f(x, Smat), y, f(x))
# adjoint_test(
# (x, A, y)->elbo(f(x, _to_psd(A)), y, f(x)),
# randn(rng), x, randn(rng, sum(Ns), sum(Ns)), y;
# atol=1e-6, rtol=1e-6,
# )
# adjoint_test(
# (x, A1, A2, y) -> begin
# S = block_diagonal([A1 * A1' + I, A2 * A2' + I])
# return elbo(f(x, S), y, f(x))
# end,
# randn(rng), x, As[1], As[2], y;
# atol=1e-6, rtol=1e-6,
# )
# end
10 changes: 5 additions & 5 deletions test/mean_function.jl
@@ -35,9 +35,9 @@
# This test fails without the specialized methods
# `mean_vector(m::CustomMean, x::ColVecs)`
# `mean_vector(m::CustomMean, x::RowVecs)`
@testset "Zygote gradients" begin
X = [1.;; 2.;; 3.;;]
y = [1., 2., 3.]
@testset "DifferentiationInterface gradients" begin
X = [1.0;; 2.0;; 3.0;;]
y = [1.0, 2.0, 3.0]
foo_mean = x -> sum(abs2, x)

function construct_finite_gp(X, lengthscale, noise)
@@ -51,7 +51,7 @@
return logpdf(gp, y)
end

@test Zygote.gradient(n -> loglike(1., n), 1.)[1] isa Real
@test Zygote.gradient(l -> loglike(l, 1.), 1.)[1] isa Real
@test only(gradient(n -> loglike(1.0, n), AutoMooncake(), 1.0)) isa Real
@test only(gradient(l -> loglike(l, 1.0), AutoMooncake(), 1.0)) isa Real
end
end
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -18,20 +18,23 @@ using AbstractGPs:
TestUtils

using Aqua
import DifferentiationInterface as DI
using DifferentiationInterface: gradient, jacobian, value_and_gradient, value_and_jacobian
using Documenter
using Distributions: MvNormal, PDMat, loglikelihood, Distributions
using FillArrays
using FiniteDifferences
using FiniteDifferences: j′vp, to_vec
using LinearAlgebra
using LinearAlgebra: AbstractTriangular
using Mooncake
using DifferentiationInterface
using PDMats: ScalMat
using Pkg
using Plots
using Random
using Statistics
using Test
using Zygote

const GROUP = get(ENV, "GROUP", "All")
const PKGDIR = dirname(dirname(pathof(AbstractGPs)))
4 changes: 2 additions & 2 deletions test/test_util.jl
@@ -55,8 +55,8 @@ function adjoint_test(
f, ȳ, x...; rtol=_rtol, atol=_atol, fdm=central_fdm(5, 1), print_results=false
)
# Compute forwards-pass and j′vp.
y, back = Zygote.pullback(f, x...)
adj_ad = back(ȳ)
_f = (x) -> f(x...)
y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, ȳ)
Review comment (suggested change): replace
    y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, ȳ)
with
    y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, (ȳ,))
In DI, tangents and cotangents are passed as tuples to enable the batched behavior of ForwardDiff and Enzyme.

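With the suggested tuple form, the result is likewise a tuple with one entry per supplied cotangent, so with a single cotangent ȳ it could be unpacked as (sketch):

    y, (adj_x,) = DI.value_and_pullback(_f, DI.AutoMooncake(), x, (ȳ,))
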
adj_fd = j′vp(fdm, f, ȳ, x...)

# Check that forwards-pass agrees with plain forwards-pass.