diff --git a/examples/1-mauna-loa/Project.toml b/examples/1-mauna-loa/Project.toml
index b51c0c81..efeb7133 100644
--- a/examples/1-mauna-loa/Project.toml
+++ b/examples/1-mauna-loa/Project.toml
@@ -2,11 +2,12 @@
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 AbstractGPs = "0.5"
@@ -16,4 +17,3 @@ Literate = "2"
 Optim = "1"
 ParameterHandling = "0.4, 0.5"
 Plots = "1"
-Zygote = "0.6, 0.7"
diff --git a/examples/1-mauna-loa/script.jl b/examples/1-mauna-loa/script.jl
index d8137f23..35823b97 100644
--- a/examples/1-mauna-loa/script.jl
+++ b/examples/1-mauna-loa/script.jl
@@ -12,7 +12,8 @@ using CSV, DataFrames # data loading
 using AbstractGPs # exact GP regression
 using ParameterHandling # for nested and constrained parameters
 using Optim # optimization
-using Zygote # auto-diff gradient computation
+import DifferentiationInterface as DI # auto-diff interface
+using Mooncake # AD backend
 using Plots # visualisation
 
 # Let's load and visualize the dataset.
@@ -225,14 +226,15 @@ function optimize_loss(loss, θ_init; optimizer=default_optimizer, maxiter=1_000
     loss_packed = loss ∘ unflatten
 
     ## https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations
+    ## TODO: enable `prep = DI.prepare_gradient(f, backend, x)`
     function fg!(F, G, x)
         if F !== nothing && G !== nothing
-            val, grad = Zygote.withgradient(loss_packed, x)
-            G .= only(grad)
+            val, grad = DI.value_and_gradient(loss_packed, DI.AutoMooncake(), x)
+            G .= grad
             return val
         elseif G !== nothing
-            grad = Zygote.gradient(loss_packed, x)
-            G .= only(grad)
+            grad = DI.gradient(loss_packed, DI.AutoMooncake(), x)
+            G .= grad
             return nothing
         elseif F !== nothing
             return loss_packed(x)
diff --git a/examples/2-deep-kernel-learning/Project.toml b/examples/2-deep-kernel-learning/Project.toml
index 1c205098..155b04f2 100644
--- a/examples/2-deep-kernel-learning/Project.toml
+++ b/examples/2-deep-kernel-learning/Project.toml
@@ -1,15 +1,16 @@
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 AbstractGPs = "0.3,0.4,0.5"
@@ -20,5 +21,4 @@ Lux = "1"
 MLDataUtils = "0.5"
 Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.7"
 julia = "1.10"
diff --git a/examples/2-deep-kernel-learning/script.jl b/examples/2-deep-kernel-learning/script.jl
index 67f3b09e..50f910f5 100644
--- a/examples/2-deep-kernel-learning/script.jl
+++ b/examples/2-deep-kernel-learning/script.jl
@@ -23,7 +23,8 @@ using Lux
 using Optimisers
 using Plots
 using Random
-using Zygote
+using Mooncake
+import DifferentiationInterface as DI
 
 default(; legendfontsize=15.0, linewidth=3.0);
 Random.seed!(42) # for reproducibility
@@ -91,7 +92,7 @@ anim = Animation()
 let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
     for i in 1:nmax
         _, loss_val, _, tstate = Training.single_train_step!(
-            AutoZygote(), update_kernel_and_loss, (), tstate
+            DI.AutoMooncake(), update_kernel_and_loss, (), tstate
         )
 
         if i % 10 == 0
diff --git a/examples/3-parametric-heteroscedastic/Project.toml b/examples/3-parametric-heteroscedastic/Project.toml
index f62fe06f..d5f29129 100644
--- a/examples/3-parametric-heteroscedastic/Project.toml
+++ b/examples/3-parametric-heteroscedastic/Project.toml
@@ -2,13 +2,14 @@
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 AbstractGPsMakie = "7834405d-1089-4985-bd30-732a30b92057"
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 AbstractGPs = "0.5"
@@ -18,4 +19,3 @@ KernelFunctions = "0.10"
 Literate = "2"
 Optim = "1"
 ParameterHandling = "0.4, 0.5"
-Zygote = "0.6, 0.7"
diff --git a/examples/3-parametric-heteroscedastic/script.jl b/examples/3-parametric-heteroscedastic/script.jl
index 1be1bcad..fb5e37da 100644
--- a/examples/3-parametric-heteroscedastic/script.jl
+++ b/examples/3-parametric-heteroscedastic/script.jl
@@ -11,10 +11,11 @@
 using AbstractGPs
 using AbstractGPsMakie
 using CairoMakie
+import DifferentiationInterface as DI
 using KernelFunctions
+using Mooncake
 using Optim
 using ParameterHandling
-using Zygote
 
 using LinearAlgebra
 using Random
@@ -47,15 +48,14 @@ end;
 
 # We use L-BFGS for optimising the objective function.
 # It is a first-order method and hence requires computing the gradient of the objective function.
-# We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation with Zygote.
-# When computing gradients with Zygote, the objective function is evaluated as well.
+# We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation, via DifferentiationInterface with the Mooncake backend.
+# When computing gradients with `DI.value_and_gradient!`, the objective function is evaluated as well.
 # We can exploit this and [avoid re-evaluating the objective function](https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations) in such cases.
 function objective_and_gradient(F, G, flat_θ)
     if G !== nothing
-        val_grad = Zygote.withgradient(objective, flat_θ)
-        copyto!(G, only(val_grad.grad))
+        val, _ = DI.value_and_gradient!(objective, G, DI.AutoMooncake(), flat_θ)
         if F !== nothing
-            return val_grad.val
+            return val
         end
     end
     if F !== nothing
diff --git a/test/Project.toml b/test/Project.toml
index 5e73a6d5..1fce9fa6 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,30 +1,32 @@
 [deps]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 Aqua = "0.8"
+DifferentiationInterface = "0.7"
 Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
 Documenter = "1"
 FillArrays = "0.11, 0.12, 0.13, 1"
 FiniteDifferences = "0.9.6, 0.10, 0.11, 0.12"
 LinearAlgebra = "1"
+Mooncake = "0.4"
 PDMats = "0.11"
 Pkg = "1"
 Plots = "1"
 Random = "1"
 Statistics = "1"
 Test = "1"
-Zygote = "0.5, 0.6, 0.7"
 julia = "1.6"
diff --git a/test/finite_gp_projection.jl b/test/finite_gp_projection.jl
index d75c1596..d42304b8 100644
--- a/test/finite_gp_projection.jl
+++ b/test/finite_gp_projection.jl
@@ -151,13 +151,9 @@ end
 
         # Check gradient of logpdf at mean is zero for `f`.
         adjoint_test(ŷ -> logpdf(fx, ŷ), 1, ones(size(ŷ)))
-        lp, back = Zygote.pullback(ŷ -> logpdf(fx, ŷ), ones(size(ŷ)))
-        @test back(randn(rng))[1] == zeros(size(ŷ))
 
         # Check that gradient of logpdf at mean is zero for `y`.
         adjoint_test(ŷ -> logpdf(y, ŷ), 1, ones(size(ŷ)))
-        lp, back = Zygote.pullback(ŷ -> logpdf(y, ŷ), ones(size(ŷ)))
-        @test back(randn(rng))[1] == zeros(size(ŷ))
 
         # Check that gradient w.r.t. inputs is approximately correct for `f`.
         x, l̄ = randn(rng, N), randn(rng)
@@ -211,128 +207,3 @@
     docstring = string(Docs.doc(logpdf, Tuple{AbstractGPs.FiniteGP,Vector{Float64}}))
     @test occursin("logpdf(f::FiniteGP, y::AbstractVecOrMat{<:Real})", docstring)
 end
-
-# """
-#     simple_gp_tests(rng::AbstractRNG, f::GP, xs::AV{<:AV}, σs::AV{<:Real})
-
-# Integration tests for simple GPs.
-# """
-# function simple_gp_tests(
-#     rng::AbstractRNG,
-#     f::GP,
-#     xs::AV{<:AV},
-#     isp_σs::AV{<:Real};
-#     atol=1e-8,
-#     rtol=1e-8,
-# )
-#     for x in xs, isp_σ in isp_σs
-
-#         # Test gradient w.r.t. random sampling.
-#         N = length(x)
-#         adjoint_test(
-#             (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2)),
-#             randn(rng, N),
-#             x,
-#             isp_σ,;
-#             atol=atol, rtol=rtol,
-#         )
-#         adjoint_test(
-#             (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2), 11),
-#             randn(rng, N, 11),
-#             x,
-#             isp_σ,;
-#             atol=atol, rtol=rtol,
-#         )
-
-#         # Check that gradient w.r.t. logpdf is correct.
-#         y, l̄ = rand(rng, f(x, exp(isp_σ))), randn(rng)
-#         adjoint_test(
-#             (x, isp_σ, y)->logpdf(f(x, exp(isp_σ)), y),
-#             l̄, x, isp_σ, y;
-#             atol=atol, rtol=rtol,
-#         )
-
-#         # Check that elbo is tight-ish when it's meant to be.
-#         fx, yx = f(x, 1e-9), f(x, exp(isp_σ))
-#         @test isapprox(elbo(yx, y, fx), logpdf(yx, y); atol=1e-6, rtol=1e-6)
-
-#         # Check that gradient w.r.t. elbo is correct.
-#         adjoint_test(
-#             (x, ŷ, isp_σ)->elbo(f(x, exp(isp_σ)), ŷ, f(x, 1e-9)),
-#             randn(rng), x, y, isp_σ;
-#             atol=1e-6, rtol=1e-6,
-#         )
-#     end
-# end

-# __foo(x) = isnothing(x) ? "nothing" : x

-# @testset "FiniteGP (integration)" begin
-#     rng = MersenneTwister(123456)
-#     xs = [collect(range(-3.0, stop=3.0, length=N)) for N in [2, 5, 10]]
-#     σs = log.([1e-1, 1e0, 1e1])
-#     for (k, name, atol, rtol) in vcat(
-#         [
-#             (EQ(), "EQ", 1e-6, 1e-6),
-#             (Linear(), "Linear", 1e-6, 1e-6),
-#             (PerEQ(), "PerEQ", 5e-5, 1e-8),
-#             (Exp(), "Exp", 1e-6, 1e-6),
-#         ],
-#         [(
-#             k(α=α, β=β, l=l),
-#             "$k_name(α=$(__foo(α)), β=$(__foo(β)), l=$(__foo(l)))",
-#             1e-6,
-#             1e-6,
-#         )
-#         for (k, k_name) in ((EQ, "EQ"), (Linear, "linear"), (Matern12, "exp"))
-#         for α in (nothing, randn(rng))
-#         for β in (nothing, exp(randn(rng)))
-#         for l in (nothing, randn(rng))
-#         ],
-#     )
-#         @testset "$name" begin
-#             simple_gp_tests(_rng(), GP(k, GPC()), xs, σs; atol=atol, rtol=rtol)
-#         end
-#     end
-# end

-# @testset "FiniteGP (BlockDiagonal obs noise)" begin
-#     rng, Ns = MersenneTwister(123456), [4, 5]
-#     x = collect(range(-5.0, 5.0; length=sum(Ns)))
-#     As = [randn(rng, N, N) for N in Ns]
-#     Ss = [A' * A + I for A in As]

-#     S = block_diagonal(Ss)
-#     Smat = Matrix(S)

-#     f = GP(cos, EQ(), GPC())
-#     y = rand(f(x, S))

-#     @test logpdf(f(x, S), y) ≈ logpdf(f(x, Smat), y)
-#     adjoint_test(
-#         (x, S, y)->logpdf(f(x, S), y), randn(rng), x, Smat, y;
-#         atol=1e-6, rtol=1e-6,
-#     )
-#     adjoint_test(
-#         (x, A1, A2, y)->logpdf(f(x, block_diagonal([A1 * A1' + I, A2 * A2' + I])), y),
-#         randn(rng), x, As[1], As[2], y;
-#         atol=1e-6, rtol=1e-6
-#     )

-#     @test elbo(f(x, Smat), y, f(x)) ≈ logpdf(f(x, Smat), y)
-#     @test elbo(f(x, S), y, f(x)) ≈
-#         elbo(f(x, Smat), y, f(x))
-#     adjoint_test(
-#         (x, A, y)->elbo(f(x, _to_psd(A)), y, f(x)),
-#         randn(rng), x, randn(rng, sum(Ns), sum(Ns)), y;
-#         atol=1e-6, rtol=1e-6,
-#     )
-#     adjoint_test(
-#         (x, A1, A2, y) -> begin
-#             S = block_diagonal([A1 * A1' + I, A2 * A2' + I])
-#             return elbo(f(x, S), y, f(x))
-#         end,
-#         randn(rng), x, As[1], As[2], y;
-#         atol=1e-6, rtol=1e-6,
-#     )
-# end
diff --git a/test/mean_function.jl b/test/mean_function.jl
index 22cb7a66..e92774e2 100644
--- a/test/mean_function.jl
+++ b/test/mean_function.jl
@@ -35,9 +35,9 @@
     # This test fails without the specialized methods
     # `mean_vector(m::CustomMean, x::ColVecs)`
    # `mean_vector(m::CustomMean, x::RowVecs)`
-    @testset "Zygote gradients" begin
-        X = [1.;; 2.;; 3.;;]
-        y = [1., 2., 3.]
+ @testset "DifferentiationInterface gradients" begin + X = [1.0;; 2.0;; 3.0;;] + y = [1.0, 2.0, 3.0] foo_mean = x -> sum(abs2, x) function construct_finite_gp(X, lengthscale, noise) @@ -51,7 +51,7 @@ return logpdf(gp, y) end - @test Zygote.gradient(n -> loglike(1., n), 1.)[1] isa Real - @test Zygote.gradient(l -> loglike(l, 1.), 1.)[1] isa Real + @test only(gradient(n -> loglike(1.0, n), AutoMooncake(), 1.0)) isa Real + @test only(gradient(l -> loglike(l, 1.0), AutoMooncake(), 1.0)) isa Real end end diff --git a/test/runtests.jl b/test/runtests.jl index d5edac8d..fef8548e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,8 @@ using AbstractGPs: TestUtils using Aqua +import DifferentiationInterface as DI +using DifferentiationInterface: gradient, jacobian, value_and_gradient, value_and_jacobian using Documenter using Distributions: MvNormal, PDMat, loglikelihood, Distributions using FillArrays @@ -25,13 +27,14 @@ using FiniteDifferences using FiniteDifferences: j′vp, to_vec using LinearAlgebra using LinearAlgebra: AbstractTriangular +using Mooncake +using DifferentiationInterface using PDMats: ScalMat using Pkg using Plots using Random using Statistics using Test -using Zygote const GROUP = get(ENV, "GROUP", "All") const PKGDIR = dirname(dirname(pathof(AbstractGPs))) diff --git a/test/test_util.jl b/test/test_util.jl index 944efe7d..bf0f03ba 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -55,8 +55,8 @@ function adjoint_test( f, ȳ, x...; rtol=_rtol, atol=_atol, fdm=central_fdm(5, 1), print_results=false ) # Compute forwards-pass and j′vp. - y, back = Zygote.pullback(f, x...) - adj_ad = back(ȳ) + _f = (x) -> f(x...) + y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, ȳ) adj_fd = j′vp(fdm, f, ȳ, x...) # Check that forwards-pass agrees with plain forwards-pass.