
Commit 8384a32

Don't use DiffResults in Flux optimiser dispatch with FiniteDiff
1 parent: 2d585bb

File tree: 2 files changed, +4 −4 lines


src/function.jl

Lines changed: 2 additions & 2 deletions
@@ -212,13 +212,13 @@ function instantiate_function(f, x, adtype::AutoFiniteDiff, p, num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))

     if f.grad === nothing
-        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res,x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
+        grad = (res, θ, args...) -> FiniteDiff.finite_difference_gradient!(res, x ->_f(x, args...), θ, FiniteDiff.GradientCache(res, x, adtype.fdtype))
     else
         grad = f.grad
     end

     if f.hess === nothing
-        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res,x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
+        hess = (res, θ, args...) -> FiniteDiff.finite_difference_hessian!(res, x ->_f(x, args...), θ, FiniteDiff.HessianCache(x, adtype.fdhtype))
     else
         hess = f.hess
     end
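For reference, the grad and hess closures above simply forward to FiniteDiff's in-place routines with pre-built caches. Below is a minimal standalone sketch of that pattern; the objective _f, the 3-element θ, and the Val(:forward)/Val(:hcentral) difference types are illustrative assumptions, not values taken from this commit.

    using FiniteDiff

    # Illustrative objective standing in for _f(θ, args...) above.
    _f(θ) = sum(abs2, θ .- 1)

    θ   = zeros(3)
    res = similar(θ)

    # Reusable caches, mirroring FiniteDiff.GradientCache(res, x, adtype.fdtype)
    # and FiniteDiff.HessianCache(x, adtype.fdhtype) in the diff above.
    gcache = FiniteDiff.GradientCache(res, θ, Val(:forward))
    hcache = FiniteDiff.HessianCache(θ, Val(:hcentral))

    # Both calls write their result into the first argument in place.
    FiniteDiff.finite_difference_gradient!(res, _f, θ, gcache)
    H = zeros(3, 3)
    FiniteDiff.finite_difference_hessian!(H, _f, θ, hcache)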

src/solve.jl

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args
 
     @withprogress progress name="Training" begin
         for (i,d) in enumerate(data)
-            gs = DiffResults.GradientResult(θ)
+            gs = prob.f.adtype isa AutoFiniteDiff ? Array{Number}(undef,length(θ)) : DiffResults.GradientResult(θ)
             f.grad(gs, θ, d...)
             x = f.f(θ, prob.p, d...)
             cb_call = cb(θ, x...)
@@ -105,7 +105,7 @@ function __solve(prob::OptimizationProblem, opt, _data = DEFAULT_DATA;cb = (args
             end
             msg = @sprintf("loss: %.3g", x[1])
             progress && ProgressLogging.@logprogress msg i/maxiters
-            update!(opt, ps, DiffResults.gradient(gs))
+            update!(opt, ps, prob.f.adtype isa AutoFiniteDiff ? gs : DiffResults.gradient(gs))

             if save_best
                 if first(x) < first(min_err) #found a better solution

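The change above works because FiniteDiff writes the gradient straight into whatever buffer it is handed, so a plain array can be passed to Flux's update! unchanged, whereas the DiffResults container used for the other AD backends still has to be unwrapped with DiffResults.gradient. Below is a rough sketch of the two paths, assuming the Flux.Optimise interface used by this solver and an illustrative loss function; the commit itself allocates the FiniteDiff buffer as Array{Number}(undef, length(θ)), while a concrete Float64 buffer is used here for simplicity.

    using FiniteDiff, ForwardDiff, DiffResults
    using Flux.Optimise: Descent, update!

    loss(θ) = sum(abs2, θ)          # stand-in for f.f(θ, prob.p, d...)
    θ   = rand(3)
    opt = Descent(0.1)

    # FiniteDiff path: the gradient lands in a plain vector, which update! accepts directly.
    gs = zeros(length(θ))
    FiniteDiff.finite_difference_gradient!(gs, loss, θ)
    update!(opt, θ, gs)

    # ForwardDiff path: the gradient is stored inside a DiffResults container,
    # so it must be unwrapped with DiffResults.gradient before the update.
    res = DiffResults.GradientResult(θ)
    ForwardDiff.gradient!(res, loss, θ)
    update!(opt, θ, DiffResults.gradient(res))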