Skip to content

Commit b177a3b

Browse files
committed
fix: write out the AD as dispatches
1 parent 3dcbbf8 commit b177a3b

File tree

1 file changed

+45
-48
lines changed

1 file changed

+45
-48
lines changed

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -143,72 +143,69 @@ function prepare_jacobian(prob, autodiff, fx, x)
143143
end
144144
end
145145

146-
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
147-
if extras isa AnalyticJacobian
148-
if SciMLBase.has_jac(prob.f)
146+
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, ::AnalyticJacobian)
147+
if SciMLBase.has_jac(prob.f)
148+
return prob.f.jac(x, prob.p)
149+
elseif SciMLBase.has_vjp(prob.f)
150+
return prob.f.vjp(one(x), x, prob.p)
151+
elseif SciMLBase.has_jvp(prob.f)
152+
return prob.f.jvp(one(x), x, prob.p)
153+
end
154+
end
155+
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, ::DIExtras)
156+
return DI.derivative(prob.f, extras.prep, autodiff, x, Constant(prob.p))
157+
end
158+
function compute_jacobian!!(_, prob, autodiff, fx, x::Number, ::DINoPreparation)
159+
return DI.derivative(prob.f, autodiff, x, Constant(prob.p))
160+
end
161+
162+
function compute_jacobian!!(J, prob, autodiff, fx, x, ::AnalyticJacobian)
163+
if J === nothing
164+
if SciMLBase.isinplace(prob.f)
165+
J = safe_similar(fx, length(fx), length(x))
166+
prob.f.jac(J, x, prob.p)
167+
return J
168+
else
149169
return prob.f.jac(x, prob.p)
150-
elseif SciMLBase.has_vjp(prob.f)
151-
return prob.f.vjp(one(x), x, prob.p)
152-
elseif SciMLBase.has_jvp(prob.f)
153-
return prob.f.jvp(one(x), x, prob.p)
154170
end
155171
end
156-
if extras isa DIExtras
157-
return DI.derivative(prob.f, extras.prep, autodiff, x, Constant(prob.p))
172+
if SciMLBase.isinplace(prob.f)
173+
prob.f.jac(J, x, prob.p)
174+
return J
158175
else
159-
return DI.derivative(prob.f, autodiff, x, Constant(prob.p))
176+
return prob.f.jac(x, prob.p)
160177
end
161178
end
162-
function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
179+
180+
function compute_jacobian!!(J, prob, autodiff, fx, x, ::DIExtras)
163181
if J === nothing
164-
if extras isa AnalyticJacobian
165-
if SciMLBase.isinplace(prob.f)
166-
J = safe_similar(fx, length(fx), length(x))
167-
prob.f.jac(J, x, prob.p)
168-
return J
169-
else
170-
return prob.f.jac(x, prob.p)
171-
end
172-
end
173-
if SciMLBase.isinplace(prob)
174-
@assert extras isa DIExtras
182+
if SciMLBase.isinplace(prob.f)
175183
return DI.jacobian(prob.f, fx, extras.prep, autodiff, x, Constant(prob.p))
176184
else
177-
if extras isa DIExtras
178-
return DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
179-
else
180-
return DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
181-
end
185+
return DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
182186
end
183187
end
184-
if extras isa AnalyticJacobian
185-
if SciMLBase.isinplace(prob)
186-
prob.f.jac(J, x, prob.p)
187-
return J
188-
else
189-
return prob.f.jac(x, prob.p)
190-
end
191-
end
192-
if SciMLBase.isinplace(prob)
193-
@assert extras isa DIExtras
194-
DI.jacobian!(prob.f, fx, J, extras.prep, autodiff, x, Constant(prob.p))
188+
if SciMLBase.isinplace(prob.f)
189+
DI.jacobian!(prob.f, J, fx, extras.prep, autodiff, x, Constant(prob.p))
195190
else
196191
if ArrayInterface.can_setindex(J)
197-
if extras isa DIExtras
198-
DI.jacobian!(prob.f, J, extras.prep, autodiff, x, Constant(prob.p))
199-
else
200-
DI.jacobian!(prob.f, J, autodiff, x, Constant(prob.p))
201-
end
192+
DI.jacobian!(prob.f, J, extras.prep, autodiff, x, Constant(prob.p))
202193
else
203-
if extras isa DIExtras
204-
J = DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
205-
else
206-
J = DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
207-
end
194+
J = DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
208195
end
209196
end
210197
return J
211198
end
199+
function compute_jacobian!!(J, prob, autodiff, fx, x, ::DINoPreparation)
200+
@assert !SciMLBase.isinplace(prob.f) "This shouldn't happen. Open an issue."
201+
J === nothing && return DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
202+
if ArrayInterface.can_setindex(J)
203+
DI.jacobian!(prob.f, J, autodiff, x, Constant(prob.p))
204+
else
205+
J = DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
206+
end
207+
return J
208+
end
212209

213210
function compute_jacobian_and_hessian(autodiff, prob, _, x::Number)
214211
H = DI.second_derivative(prob.f, autodiff, x, Constant(prob.p))

0 commit comments

Comments
 (0)