diff --git a/examples/Project.toml b/examples/Project.toml index e095cc7..37e88f1 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,5 +1,6 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec" @@ -7,5 +8,5 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NewtonKrylov = "0be81120-40bf-4f8b-adf0-26103efb66f1" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[sources.NewtonKrylov] -path = ".." +[sources] +NewtonKrylov = {path = ".."} diff --git a/examples/simple_adjoint.jl b/examples/simple_adjoint.jl new file mode 100644 index 0000000..fc52031 --- /dev/null +++ b/examples/simple_adjoint.jl @@ -0,0 +1,88 @@ +## Simple 2D example from (Kelley2003)[@cite] + +using NewtonKrylov, LinearAlgebra +using CairoMakie + +function F!(res, x, p) + res[1] = p[1] * x[1]^2 + p[2] * x[2]^2 - 2 + return res[2] = exp(p[1] * x[1] - 1) + p[2] * x[2]^2 - 2 + # return nothing +end + +function F(x, p) + res = similar(x) + F!(res, x, p) + return res +end + +p = [1.0, 1.3, 1.0] + +xs = LinRange(-3, 8, 1000) +ys = LinRange(-15, 10, 1000) + +levels = [0.1, 0.25, 0.5:2:10..., 10.0:10:200..., 200:100:4000...] + +fig, ax = contour(xs, ys, (x, y) -> norm(F([x, y], p)); levels) + + +x₀ = [2.0, 0.5] +x, stats = newton_krylov!((res, u) -> F!(res, u, p), x₀) +@assert stats.solved + +# ## Adjoint setup +# Define x̂ to be a solution we would like discover the parameter of. + +const x̂ = [1.0000001797004159, 1.0000001140397106] + +# `g` is our target function measuring the distance +function g(x, p) + return sum(abs2, x .- x̂) +end + +using Enzyme +using Krylov + +# function adjoint_with_primal(F!, G, u₀, p; kwargs...) +# res = similar(u₀) +# u, stats = newton_krylov!(F!, u₀, res; kwargs...) +# # @assert stats.solved + +# return (; u, loss = G(u, p), dp = adjoint_nl!(F!, G, res, u, p)) +# end + +""" + adjoint_nl!(F!, G, res, u, p) + +# Arguments +- `F!` -> F!(res, u, p) solves F(u; p) = 0 +- `G` -> "Target function"/ "Loss function" G(u, p) = scalar +""" +function adjoint_nl!(F!, G, res, u, p) + # Calculate gₚ and gₓ + gₚ = Enzyme.make_zero(p) + gₓ = Enzyme.make_zero(u) + Enzyme.autodiff(Enzyme.Reverse, G, Duplicated(u, gₓ), Duplicated(p, gₚ)) + + # Solve adjoint equation for λ + J = NewtonKrylov.JacobianOperator((res, u) -> F!(res, u, p), res, u) + λ, stats = gmres(transpose(J), gₓ) + @assert stats.solved + + # Now do vJp for λᵀ*fₚ + dp = Enzyme.make_zero(p) + Enzyme.autodiff( + Enzyme.Reverse, F!, Const, + Duplicated(res, λ), + Const(u), + Duplicated(p, dp) + ) + + # TODO: + # Use recursive_map to implement this subtraction https://github.com/EnzymeAD/Enzyme.jl/pull/1852 + return gₚ - dp +end + +adjoint_with_primal(F!, g, x₀, p) + +# ## TODO: +# Use Optimizer.jl to find `p`