Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NewtonKrylov = "0be81120-40bf-4f8b-adf0-26103efb66f1"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[sources.NewtonKrylov]
path = ".."
[sources]
NewtonKrylov = {path = ".."}
88 changes: 88 additions & 0 deletions examples/simple_adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
## Simple 2D example from [Kelley2003](@cite)

using NewtonKrylov, LinearAlgebra
using CairoMakie

"""
    F!(res, x, p)

In-place residual of the 2D nonlinear system from Kelley (2003),
parameterized by `p`:

    res[1] = p₁·x₁² + p₂·x₂² - 2
    res[2] = exp(p₁·x₁ - 1) + p₂·x₂² - 2

Writes into `res` and returns `nothing`, so reverse-mode Enzyme calls that
declare the return activity as `Const` (see the adjoint code below) do not
see an active returned scalar.
"""
function F!(res, x, p)
    res[1] = p[1] * x[1]^2 + p[2] * x[2]^2 - 2
    res[2] = exp(p[1] * x[1] - 1) + p[2] * x[2]^2 - 2
    return nothing
end

"""
    F(x, p)

Out-of-place companion to `F!`: allocates a residual vector shaped like
`x`, fills it via `F!`, and returns it.
"""
function F(x, p)
    out = similar(x)
    F!(out, x, p)
    return out
end

# Parameter vector for F!. NOTE(review): F! only reads p[1] and p[2];
# the third entry appears unused — confirm whether it is intentional.
p = [1.0, 1.3, 1.0]

# Grid over which to visualize the residual norm.
xs = LinRange(-3, 8, 1000)
ys = LinRange(-15, 10, 1000)

# Contour levels: dense near zero (where the roots live), coarser further out.
levels = [0.1, 0.25, 0.5:2:10..., 10.0:10:200..., 200:100:4000...]

# Contour plot of ‖F(x, y; p)‖ showing the zero level sets of the system.
fig, ax = contour(xs, ys, (x, y) -> norm(F([x, y], p)); levels)


# Primal solve: find x with F(x; p) = 0 starting from x₀, closing over p
# so newton_krylov! sees the expected 2-argument residual (res, u).
x₀ = [2.0, 0.5]
x, stats = newton_krylov!((res, u) -> F!(res, u, p), x₀)
@assert stats.solved

# ## Adjoint setup
# Define x̂ to be a solution whose parameters we would like to discover.

# Reference solution whose generating parameters we want to recover.
const x̂ = [1.0000001797004159, 1.0000001140397106]

"""
    g(x, p)

Loss: squared Euclidean distance between `x` and the reference solution
`x̂`. The parameter `p` is accepted for a uniform `G(u, p)` signature but
is unused.
"""
g(x, p) = mapreduce(abs2, +, x .- x̂)

using Enzyme
using Krylov

"""
    adjoint_with_primal(F!, G, u₀, p; kwargs...)

Solve the primal problem `F(u; p) = 0` starting from `u₀`, then evaluate
the loss `G(u, p)` and the parameter gradient via the discrete adjoint.

Returns a named tuple `(; u, loss, dp)` with the converged solution, the
loss value, and `dp = dG/dp`.

This function is called at the bottom of this example; it was previously
commented out, which made that call fail with an `UndefVarError`.
"""
function adjoint_with_primal(F!, G, u₀, p; kwargs...)
    res = similar(u₀)
    # Close over p so newton_krylov! sees the 2-argument residual form
    # (res, u), matching its use elsewhere in this file.
    u, stats = newton_krylov!((res, u) -> F!(res, u, p), u₀, res; kwargs...)
    stats.solved || error("Newton–Krylov primal solve did not converge")

    return (; u, loss = G(u, p), dp = adjoint_nl!(F!, G, res, u, p))
end

"""
adjoint_nl!(F!, G, res, u, p)

# Arguments
- `F!` -> F!(res, u, p) solves F(u; p) = 0
- `G` -> "Target function"/ "Loss function" G(u, p) = scalar
"""
function adjoint_nl!(F!, G, res, u, p)
# Calculate gₚ and gₓ
gₚ = Enzyme.make_zero(p)
gₓ = Enzyme.make_zero(u)
Enzyme.autodiff(Enzyme.Reverse, G, Duplicated(u, gₓ), Duplicated(p, gₚ))

# Solve adjoint equation for λ
J = NewtonKrylov.JacobianOperator((res, u) -> F!(res, u, p), res, u)
λ, stats = gmres(transpose(J), gₓ)
@assert stats.solved

# Now do vJp for λᵀ*fₚ
dp = Enzyme.make_zero(p)
Enzyme.autodiff(
Enzyme.Reverse, F!, Const,
Duplicated(res, λ),
Const(u),
Duplicated(p, dp)
)

# TODO:
# Use recursive_map to implement this subtraction https://github.com/EnzymeAD/Enzyme.jl/pull/1852
return gₚ - dp
end

# Run the full pipeline: primal solve, loss evaluation, and parameter
# gradient dG/dp. NOTE(review): requires `adjoint_with_primal` (above) to
# be defined, not commented out — verify before running this example.
adjoint_with_primal(F!, g, x₀, p)

# ## TODO:
# Use Optimizer.jl to find `p`