Skip to content

Commit 5fb8781

Browse files
committed
wrap it up in a handy function
1 parent 00cc2df commit 5fb8781

File tree

2 files changed

+160
-12
lines changed

2 files changed

+160
-12
lines changed

examples/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
[deps]
22
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
34
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
45
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
56
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
NewtonKrylov = "0be81120-40bf-4f8b-adf0-26103efb66f1"
89
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
910

10-
[sources.NewtonKrylov]
11-
path = ".."
11+
[sources]
12+
NewtonKrylov = {path = ".."}

examples/simple_adjoint.jl

Lines changed: 157 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using CairoMakie
66
function F!(res, x, p)
77
res[1] = p[1] * x[1]^2 + p[2] * x[2]^2 - 2
88
return res[2] = exp(p[1] * x[1] - 1) + p[2] * x[2]^2 - 2
9+
# return nothing
910
end
1011

1112
function F(x, p)
@@ -25,7 +26,8 @@ fig, ax = contour(xs, ys, (x, y) -> norm(F([x, y], p)); levels)
2526

2627

2728
x₀ = [2.0, 0.5]
28-
x, stats = newton_krylov((u) -> F(u, p), x₀)
29+
x, stats = newton_krylov!((res, u) -> F!(res, u, p), x₀)
30+
@assert stats.solved
2931

3032
const x̂ = [1.0000001797004159, 1.0000001140397106]
3133

@@ -55,17 +57,162 @@ gₓ = dg_x(x, p)
5557

5658
Fₚ = Enzyme.jacobian(Enzyme.Reverse, p -> F(x, p), p) |> only
5759

58-
Jᵀ = Enzyme.jacobian(Enzyme.Reverse, u -> F(u, p), x) |> only
60+
J = Enzyme.jacobian(Enzyme.Reverse, u -> F(u, p), x) |> only
5961

60-
q = Jᵀ \ gₓ
62+
q = transpose(J) \ gₓ
6163

62-
gₚ - reshape(transpose(q) * Fₚ, :)
64+
@show gₓ
65+
display(transpose(J))
66+
q2, stats = gmres(transpose(J), gₓ)
67+
@assert q == q2
6368

64-
function everything_all_at_once(p)
65-
x₀ = [2.0, 0.5]
66-
x, _ = newton_krylov((u) -> F(u, p), x₀)
67-
return g(x, p)
69+
gₚ - transpose(Fₚ) * q
70+
71+
72+
# dp = vJp_p(F!, res, x, p, q)
73+
74+
75+
# F!(res, x, p); res = 0 || F(x,p) = 0
76+
# function vJp_x(F!, res, x, p, v)
77+
# dx = Enzyme.make_zero(x)
78+
# Enzyme.autodiff(Enzyme.Reverse, F!,
79+
# DuplicatedNoNeed(res, reshape(v, size(res))),
80+
# Duplicated(x, dx),
81+
# Const(p))
82+
# dx
83+
# end
84+
85+
# function vJp_p(F!, res, x, p, v)
86+
# dp = Enzyme.make_zero(p)
87+
# Enzyme.autodiff(Enzyme.Reverse, F!,
88+
# DuplicatedNoNeed(res, reshape(v, size(res))),
89+
# Const(x),
90+
# Duplicated(p, dp))
91+
# dp
92+
# end
93+
94+
95+
# Notes: "discretise-then-optimise", or "backpropagation through the solver" has the benefit of only requiring "recursive accumulate" on the shadow
96+
# whereas "continuous adjoint" after SGJ notes requires parameters to be "vector" like.
97+
98+
# function everything_all_at_once(p)
99+
# x₀ = [2.0, 0.5]
100+
# x, _ = newton_krylov((u) -> F(u, p), x₀)
101+
# return g(x, p)
102+
# end
103+
104+
# everything_all_at_once(p)
105+
# Enzyme.gradient(Enzyme.Reverse, everything_all_at_once, p)
106+
107+
# struct JacobianOperatorP{F, A}
108+
# f::F # F!(res, u, p)
109+
# res::A
110+
# u::A
111+
# p::AbstractArray
112+
# function JacobianOperatorP(f::F, res, u, p) where {F}
113+
# return new{F, typeof(u), typeof(p)}(f, res, u, p)
114+
# end
115+
# end
116+
117+
# Base.size(J::JacobianOperatorP) = (length(J.res), length(J.p))
118+
# Base.eltype(J::JacobianOperatorP) = eltype(J.u)
119+
120+
# function mul!(out, J::JacobianOperatorP, v)
121+
# # Enzyme.make_zero!(J.f_cache)
122+
# f_cache = Enzyme.make_zero(J.f) # Stop gap until we can zero out mutable values
123+
# autodiff(
124+
# Forward,
125+
# maybe_duplicated(J.f, f_cache), Const,
126+
# DuplicatedNoNeed(J.res, reshape(out, size(J.res))),
127+
# Const(J.u),
128+
# # DuplicatedNoNeed(J.u, Enzyme.make_zero(J.u)) #, reshape(v, size(J.u)))
129+
# )
130+
# return nothing
131+
# end
132+
133+
134+
# #####
135+
136+
# function dg(x,dx, y, dy)
137+
# _x = x[1]
138+
# _y = y[1]
139+
# _a = _x * _y
140+
141+
# x[1] = _a
142+
# #
143+
# _da = 0
144+
# _da += dx[1]
145+
# dx[1] = 0
146+
147+
# _dx = 0
148+
# _dx += _da*_y
149+
150+
# _dy = 0
151+
# _dy += _da*_x
152+
153+
# _da = 0
154+
155+
# dy[1] += _dy
156+
# _dy = 0
157+
# dx[1] += _dx
158+
159+
# function f(x, y)
160+
# g(x, y)
161+
# g(x, y)
162+
# end
163+
164+
# x = [1.0]
165+
# y = [1.0]
166+
# dx = [1.0]
167+
# dy = [0.0]
168+
169+
# autodiff(Enzyme.Reverse, f, Duplicated(x, dx), Duplicated(y, dy))
170+
171+
dx
172+
dy
173+
174+
using Enzyme
175+
using Krylov
176+
177+
function adjoint_with_primal(F!, G, u₀, p; kwargs...)
178+
res = similar(u₀)
179+
# TODO: Adjust newton_krylov interface to work with `F(u, p)`
180+
u, stats = newton_krylov!((res, u) -> F!(res, u, p), u₀, res; kwargs...)
181+
# @assert stats.solved
182+
183+
return (; u, loss = G(u, p), dp = adjoint_nl!(F!, G, res, u, p))
68184
end
69185

70-
everything_all_at_once(p)
71-
Enzyme.gradient(Enzyme.Reverse, everything_all_at_once, p)
186+
"""
187+
adjoint_nl!(F!, G, res, u, p)
188+
189+
# Arguments
190+
- `F!` -> F!(res, u, p) solves F(u; p) = 0
191+
- `G` -> "Target function"/ "Loss function" G(u, p) = scalar
192+
"""
193+
function adjoint_nl!(F!, G, res, u, p)
194+
# Calculate gₚ and gₓ
195+
gₚ = Enzyme.make_zero(p)
196+
gₓ = Enzyme.make_zero(u)
197+
Enzyme.autodiff(Enzyme.Reverse, G, Duplicated(u, gₓ), Duplicated(p, gₚ))
198+
199+
# Solve adjoint equation for λ
200+
J = NewtonKrylov.JacobianOperator((res, u) -> F!(res, u, p), res, u)
201+
λ, stats = gmres(transpose(J), gₓ) # todo why no transpose(gₓ)
202+
@assert stats.solved
203+
204+
# Now do vJp for λᵀ*fₚ
205+
dp = Enzyme.make_zero(p)
206+
Enzyme.autodiff(
207+
Enzyme.Reverse, F!, Const,
208+
DuplicatedNoNeed(res, λ),
209+
Const(u),
210+
DuplicatedNoNeed(p, dp)
211+
)
212+
213+
return gₚ - dp
214+
end
215+
216+
adjoint_nl!(F!, g, similar(x), x, p)
217+
218+
adjoint_with_primal(F!, g, x₀, p)

0 commit comments

Comments
 (0)