-
Notifications
You must be signed in to change notification settings - Fork 24
Replace Zygote with DifferentiationInterface + Mooncake for automatic differentiation #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f3afdec
2e3ac84
9b4d11c
96a0bd9
d2e5b9d
d41c9a1
400b859
f762a6f
7b6fc4d
cb57b09
20adc06
305b761
280354d
3b929e4
0525191
503dd3d
91e99d7
5ede796
33db1e8
861594a
833dcdf
8ccd32b
30af1f9
121129e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,8 @@ using CSV, DataFrames # data loading | |
| using AbstractGPs # exact GP regression | ||
| using ParameterHandling # for nested and constrained parameters | ||
| using Optim # optimization | ||
| using Zygote # auto-diff gradient computation | ||
| import DifferentiationInterface as DI # auto-diff interface | ||
| using Mooncake # AD backend | ||
| using Plots # visualisation | ||
|
|
||
| # Let's load and visualize the dataset. | ||
|
|
@@ -225,14 +226,15 @@ function optimize_loss(loss, θ_init; optimizer=default_optimizer, maxiter=1_000 | |
| loss_packed = loss ∘ unflatten | ||
|
|
||
| ## https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations | ||
| ## TODO: enable `prep = DI.prepare_gradient(f, backend, x)` | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is really important for performance. |
||
| function fg!(F, G, x) | ||
| if F !== nothing && G !== nothing | ||
| val, grad = Zygote.withgradient(loss_packed, x) | ||
| G .= only(grad) | ||
| val, grad = DI.value_and_gradient(loss_packed, AutoMooncake(), x) | ||
| G .= grad | ||
|
Comment on lines
+232
to
+233
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not use the in-place `DI.value_and_gradient!`, which writes the gradient directly into `G`? |
||
| return val | ||
| elseif G !== nothing | ||
| grad = Zygote.gradient(loss_packed, x) | ||
| G .= only(grad) | ||
| grad = DI.gradient(loss_packed, AutoMooncake(), x) | ||
| G .= grad | ||
|
Comment on lines
+236
to
+237
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not use the in-place `DI.gradient!`, which writes the gradient directly into `G`? |
||
| return nothing | ||
| elseif F !== nothing | ||
| return loss_packed(x) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,15 +1,16 @@ | ||
| [deps] | ||
| AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" | ||
| Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
| DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
| KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" | ||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
| Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" | ||
| Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
| MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" | ||
| Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
| Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
| Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
|
||
| [compat] | ||
|
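There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing `[compat]` bounds for DifferentiationInterface (DI) and Mooncake, e.g. `DifferentiationInterface = "0.7"` and `Mooncake = "0.4"` as in the test environment. |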
||
| AbstractGPs = "0.3,0.4,0.5" | ||
|
|
@@ -20,5 +21,4 @@ Lux = "1" | |
| MLDataUtils = "0.5" | ||
| Optimisers = "0.4" | ||
| Plots = "1" | ||
| Zygote = "0.7" | ||
| julia = "1.10" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,14 @@ | |
| AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" | ||
| AbstractGPsMakie = "7834405d-1089-4985-bd30-732a30b92057" | ||
| CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
| DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
| KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" | ||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
| Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" | ||
| Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
| Optim = "429524aa-4258-5aef-a3af-852621145aeb" | ||
| ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
|
||
| [compat] | ||
|
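There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing `[compat]` bounds for DifferentiationInterface (DI) and Mooncake, e.g. `DifferentiationInterface = "0.7"` and `Mooncake = "0.4"` as in the test environment. |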
||
| AbstractGPs = "0.5" | ||
|
|
@@ -18,4 +19,3 @@ KernelFunctions = "0.10" | |
| Literate = "2" | ||
| Optim = "1" | ||
| ParameterHandling = "0.4, 0.5" | ||
| Zygote = "0.6, 0.7" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,10 +11,11 @@ | |
| using AbstractGPs | ||
| using AbstractGPsMakie | ||
| using CairoMakie | ||
| import DifferentiationInterface as DI | ||
| using KernelFunctions | ||
| using Mooncake | ||
| using Optim | ||
| using ParameterHandling | ||
| using Zygote | ||
|
|
||
| using LinearAlgebra | ||
| using Random | ||
|
|
@@ -47,15 +48,14 @@ end; | |
|
|
||
| # We use L-BFGS for optimising the objective function. | ||
| # It is a first-order method and hence requires computing the gradient of the objective function. | ||
| # We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation with Zygote. | ||
| # When computing gradients with Zygote, the objective function is evaluated as well. | ||
| # We do not derive and implement the gradient function manually here but instead use reverse-mode automatic differentiation with DifferentiationInterface + Mooncake. | ||
| # When computing gradients, the objective function is evaluated as well. | ||
| # We can exploit this and [avoid re-evaluating the objective function](https://julianlsolvers.github.io/Optim.jl/stable/#user/tipsandtricks/#avoid-repeating-computations) in such cases. | ||
| function objective_and_gradient(F, G, flat_θ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. TODO: enable preparation with `prep = DI.prepare_gradient(f, backend, x)` and pass `prep` to the gradient call. |
||
| if G !== nothing | ||
| val_grad = Zygote.withgradient(objective, flat_θ) | ||
| copyto!(G, only(val_grad.grad)) | ||
| val, grad = DI.value_and_gradient!(objective, G, DI.AutoMooncake(), flat_θ) | ||
| if F !== nothing | ||
| return val_grad.val | ||
| return val | ||
| end | ||
| end | ||
| if F !== nothing | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,30 +1,32 @@ | ||
| [deps] | ||
| Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
| DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
| Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
| Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
| FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
| FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" | ||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
| Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
| PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
| Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
| Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
| Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
|
||
| [compat] | ||
| Aqua = "0.8" | ||
| DifferentiationInterface = "0.7" | ||
| Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" | ||
| Documenter = "1" | ||
| FillArrays = "0.11, 0.12, 0.13, 1" | ||
| FiniteDifferences = "0.9.6, 0.10, 0.11, 0.12" | ||
| LinearAlgebra = "1" | ||
| Mooncake = "0.4" | ||
| PDMats = "0.11" | ||
| Pkg = "1" | ||
| Plots = "1" | ||
| Random = "1" | ||
| Statistics = "1" | ||
| Test = "1" | ||
| Zygote = "0.5, 0.6, 0.7" | ||
| julia = "1.6" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -151,13 +151,9 @@ end | |
|
|
||
| # Check gradient of logpdf at mean is zero for `f`. | ||
| adjoint_test(ŷ -> logpdf(fx, ŷ), 1, ones(size(ŷ))) | ||
| lp, back = Zygote.pullback(ŷ -> logpdf(fx, ŷ), ones(size(ŷ))) | ||
| @test back(randn(rng))[1] == zeros(size(ŷ)) | ||
|
Comment on lines
-154
to
-155
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why were these tests removed? |
||
|
|
||
| # Check that gradient of logpdf at mean is zero for `y`. | ||
| adjoint_test(ŷ -> logpdf(y, ŷ), 1, ones(size(ŷ))) | ||
| lp, back = Zygote.pullback(ŷ -> logpdf(y, ŷ), ones(size(ŷ))) | ||
| @test back(randn(rng))[1] == zeros(size(ŷ)) | ||
|
|
||
| # Check that gradient w.r.t. inputs is approximately correct for `f`. | ||
| x, l̄ = randn(rng, N), randn(rng) | ||
|
|
@@ -211,128 +207,3 @@ end | |
| docstring = string(Docs.doc(logpdf, Tuple{AbstractGPs.FiniteGP,Vector{Float64}})) | ||
| @test occursin("logpdf(f::FiniteGP, y::AbstractVecOrMat{<:Real})", docstring) | ||
| end | ||
|
|
||
| # """ | ||
| # simple_gp_tests(rng::AbstractRNG, f::GP, xs::AV{<:AV}, σs::AV{<:Real}) | ||
|
|
||
| # Integration tests for simple GPs. | ||
| # """ | ||
| # function simple_gp_tests( | ||
| # rng::AbstractRNG, | ||
| # f::GP, | ||
| # xs::AV{<:AV}, | ||
| # isp_σs::AV{<:Real}; | ||
| # atol=1e-8, | ||
| # rtol=1e-8, | ||
| # ) | ||
| # for x in xs, isp_σ in isp_σs | ||
|
|
||
| # # Test gradient w.r.t. random sampling. | ||
| # N = length(x) | ||
| # adjoint_test( | ||
| # (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2)), | ||
| # randn(rng, N), | ||
| # x, | ||
| # isp_σ,; | ||
| # atol=atol, rtol=rtol, | ||
| # ) | ||
| # adjoint_test( | ||
| # (x, isp_σ)->rand(_rng(), f(x, exp(isp_σ)^2), 11), | ||
| # randn(rng, N, 11), | ||
| # x, | ||
| # isp_σ,; | ||
| # atol=atol, rtol=rtol, | ||
| # ) | ||
|
|
||
| # # Check that gradient w.r.t. logpdf is correct. | ||
| # y, l̄ = rand(rng, f(x, exp(isp_σ))), randn(rng) | ||
| # adjoint_test( | ||
| # (x, isp_σ, y)->logpdf(f(x, exp(isp_σ)), y), | ||
| # l̄, x, isp_σ, y; | ||
| # atol=atol, rtol=rtol, | ||
| # ) | ||
|
|
||
| # # Check that elbo is tight-ish when it's meant to be. | ||
| # fx, yx = f(x, 1e-9), f(x, exp(isp_σ)) | ||
| # @test isapprox(elbo(yx, y, fx), logpdf(yx, y); atol=1e-6, rtol=1e-6) | ||
|
|
||
| # # Check that gradient w.r.t. elbo is correct. | ||
| # adjoint_test( | ||
| # (x, ŷ, isp_σ)->elbo(f(x, exp(isp_σ)), ŷ, f(x, 1e-9)), | ||
| # randn(rng), x, y, isp_σ; | ||
| # atol=1e-6, rtol=1e-6, | ||
| # ) | ||
| # end | ||
| # end | ||
|
|
||
| # __foo(x) = isnothing(x) ? "nothing" : x | ||
|
|
||
| # @testset "FiniteGP (integration)" begin | ||
| # rng = MersenneTwister(123456) | ||
| # xs = [collect(range(-3.0, stop=3.0, length=N)) for N in [2, 5, 10]] | ||
| # σs = log.([1e-1, 1e0, 1e1]) | ||
| # for (k, name, atol, rtol) in vcat( | ||
| # [ | ||
| # (EQ(), "EQ", 1e-6, 1e-6), | ||
| # (Linear(), "Linear", 1e-6, 1e-6), | ||
| # (PerEQ(), "PerEQ", 5e-5, 1e-8), | ||
| # (Exp(), "Exp", 1e-6, 1e-6), | ||
| # ], | ||
| # [( | ||
| # k(α=α, β=β, l=l), | ||
| # "$k_name(α=$(__foo(α)), β=$(__foo(β)), l=$(__foo(l)))", | ||
| # 1e-6, | ||
| # 1e-6, | ||
| # ) | ||
| # for (k, k_name) in ((EQ, "EQ"), (Linear, "linear"), (Matern12, "exp")) | ||
| # for α in (nothing, randn(rng)) | ||
| # for β in (nothing, exp(randn(rng))) | ||
| # for l in (nothing, randn(rng)) | ||
| # ], | ||
| # ) | ||
| # @testset "$name" begin | ||
| # simple_gp_tests(_rng(), GP(k, GPC()), xs, σs; atol=atol, rtol=rtol) | ||
| # end | ||
| # end | ||
| # end | ||
|
|
||
| # @testset "FiniteGP (BlockDiagonal obs noise)" begin | ||
| # rng, Ns = MersenneTwister(123456), [4, 5] | ||
| # x = collect(range(-5.0, 5.0; length=sum(Ns))) | ||
| # As = [randn(rng, N, N) for N in Ns] | ||
| # Ss = [A' * A + I for A in As] | ||
|
|
||
| # S = block_diagonal(Ss) | ||
| # Smat = Matrix(S) | ||
|
|
||
| # f = GP(cos, EQ(), GPC()) | ||
| # y = rand(f(x, S)) | ||
|
|
||
| # @test logpdf(f(x, S), y) ≈ logpdf(f(x, Smat), y) | ||
| # adjoint_test( | ||
| # (x, S, y)->logpdf(f(x, S), y), randn(rng), x, Smat, y; | ||
| # atol=1e-6, rtol=1e-6, | ||
| # ) | ||
| # adjoint_test( | ||
| # (x, A1, A2, y)->logpdf(f(x, block_diagonal([A1 * A1' + I, A2 * A2' + I])), y), | ||
| # randn(rng), x, As[1], As[2], y; | ||
| # atol=1e-6, rtol=1e-6 | ||
| # ) | ||
|
|
||
| # @test elbo(f(x, Smat), y, f(x)) ≈ logpdf(f(x, Smat), y) | ||
| # @test elbo(f(x, S), y, f(x)) ≈ | ||
| # elbo(f(x, Smat), y, f(x)) | ||
| # adjoint_test( | ||
| # (x, A, y)->elbo(f(x, _to_psd(A)), y, f(x)), | ||
| # randn(rng), x, randn(rng, sum(Ns), sum(Ns)), y; | ||
| # atol=1e-6, rtol=1e-6, | ||
| # ) | ||
| # adjoint_test( | ||
| # (x, A1, A2, y) -> begin | ||
| # S = block_diagonal([A1 * A1' + I, A2 * A2' + I]) | ||
| # return elbo(f(x, S), y, f(x)) | ||
| # end, | ||
| # randn(rng), x, As[1], As[2], y; | ||
| # atol=1e-6, rtol=1e-6, | ||
| # ) | ||
| # end | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -55,8 +55,8 @@ function adjoint_test( | |||||
| f, ȳ, x...; rtol=_rtol, atol=_atol, fdm=central_fdm(5, 1), print_results=false | ||||||
| ) | ||||||
| # Compute forwards-pass and j′vp. | ||||||
| y, back = Zygote.pullback(f, x...) | ||||||
| adj_ad = back(ȳ) | ||||||
| _f = (x) -> f(x...) | ||||||
| y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, ȳ) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change (code lost in extraction; reconstructed from the explanation below): `y, adj_ad = DI.value_and_pullback(_f, DI.AutoMooncake(), x, (ȳ,))`
In DI, tangents and cotangents are passed as tuples to enable the batched behavior of ForwardDiff and Enzyme. |
||||||
| adj_fd = j′vp(fdm, f, ȳ, x...) | ||||||
|
|
||||||
| # Check that forwards-pass agrees with plain forwards-pass. | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing compat bounds for DI and Mooncake