Skip to content

Commit 72f5d07

Browse files
Merge pull request #203 from ChrisRackauckas-Claude/fix-repeated-fx0-evaluation
Fix repeated evaluation of fx0 in forward gradient computation
2 parents 83ff37e + 62990cb commit 72f5d07

File tree

3 files changed

+8
-22
lines changed

3 files changed

+8
-22
lines changed

src/gradients.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -372,37 +372,23 @@ function finite_difference_gradient!(
372372
end
373373
copyto!(c3, x)
374374
if fdtype == Val(:forward)
375+
fx0 = fx !== nothing ? fx : f(x)
375376
for i in eachindex(x)
376377
epsilon = compute_epsilon(fdtype, x[i], relstep, absstep, dir)
377378
x_old = x[i]
378-
if typeof(fx) != Nothing
379-
c3[i] += epsilon
380-
dfi = (f(c3) - fx) / epsilon
381-
c3[i] = x_old
382-
else
383-
fx0 = f(x)
384-
c3[i] += epsilon
385-
dfi = (f(c3) - fx0) / epsilon
386-
c3[i] = x_old
387-
end
379+
c3[i] += epsilon
380+
dfi = (f(c3) - fx0) / epsilon
381+
c3[i] = x_old
388382

389383
df[i] = real(dfi)
390384
if eltype(df) <: Complex
391385
if eltype(x) <: Complex
392386
c3[i] += im * epsilon
393-
if typeof(fx) != Nothing
394-
dfi = (f(c3) - fx) / (im * epsilon)
395-
else
396-
dfi = (f(c3) - fx0) / (im * epsilon)
397-
end
387+
dfi = (f(c3) - fx0) / (im * epsilon)
398388
c3[i] = x_old
399389
else
400390
c1[i] += im * epsilon
401-
if typeof(fx) != Nothing
402-
dfi = (f(c1) - fx) / (im * epsilon)
403-
else
404-
dfi = (f(c1) - fx0) / (im * epsilon)
405-
end
391+
dfi = (f(c1) - fx0) / (im * epsilon)
406392
c1[i] = x_old
407393
end
408394
df[i] -= im * imag(dfi)

test/downstream/ordinarydiffeq_tridiagonal_solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ function loss(p)
2424
sol = solve(_prob, Rodas4P(autodiff=false), saveat=0.1)
2525
sum((sol .- sol_true).^2)
2626
end
27-
@test ForwardDiff.gradient(loss, [1.0])[1] 0.6662949361011025
27+
@test ForwardDiff.gradient(loss, [1.0])[1] 0.6645766813735486
28+

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ if GROUP == "All" || GROUP == "Downstream"
2424
@time @safetestset "ODEs" begin
2525
import OrdinaryDiffEq
2626
@time @safetestset "OrdinaryDiffEq Tridiagonal" begin include("downstream/ordinarydiffeq_tridiagonal_solve.jl") end
27-
include(joinpath(dirname(pathof(OrdinaryDiffEq)), "..", "test/interface/sparsediff_tests.jl"))
2827
end
2928
end
3029

0 commit comments

Comments
 (0)