Skip to content

Commit 4603b4a

Browse files
Revert non-square and add downstream test
1 parent 89654d7 commit 4603b4a

File tree

4 files changed

+33
-4
lines changed

4 files changed

+33
-4
lines changed

src/jacobians.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ function finite_difference_jacobian!(
373373
# Now return x1 back to its original value
374374
ArrayInterface.allowed_setindex!(x1, x1_save, color_i)
375375
else # Perturb along the colorvec vector
376-
idx = findfirst(isequal(color_i), _color)
377-
tmp = norm(x1[idx])
376+
@. fx1 = x1 * (_color == color_i)
377+
tmp = norm(fx1)
378378
epsilon = compute_epsilon(Val(:forward), sqrt(tmp), relstep, absstep, dir)
379379
@. x1 = x1 + epsilon * (_color == color_i)
380380
f(fx1, x1)
@@ -416,8 +416,8 @@ function finite_difference_jacobian!(
416416
@. J[:,color_i] = (vfx1 - vfx) / 2epsilon
417417
ArrayInterface.allowed_setindex!(x1, x_save, color_i)
418418
else # Perturb along the colorvec vector
419-
idx = findfirst(isequal(color_i), _color)
420-
tmp = norm(x1[idx])
419+
@. fx1 = x1 * (_color == color_i)
420+
tmp = norm(fx1)
421421
epsilon = compute_epsilon(Val(:central), sqrt(tmp), relstep, absstep, dir)
422422
@. x1 = x1 + epsilon * (_color == color_i)
423423
@. x = x - epsilon * (_color == color_i)

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
22
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
3+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
34
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using OrdinaryDiffEq, ForwardDiff, Test
2+
3+
const nknots = 10
4+
const h = 1.0/(nknots+1)
5+
x = range(0, step=h, length=nknots)
6+
u0 = sin.(π*x)
7+
8+
@inline function f(du,u,p,t)
9+
du .= zero(eltype(u))
10+
u₃ = @view u[3:end]
11+
u₂ = @view u[2:end-1]
12+
u₁ = @view u[1:end-2]
13+
@. du[2:end-1] = p[1]*((u₃ - 2*u₂ + u₁)/(h^2.0))
14+
nothing
15+
end
16+
17+
p_true = [0.42]
18+
jac_proto = Tridiagonal(similar(u0,nknots-1), similar(u0), similar(u0, nknots-1))
19+
prob = ODEProblem(ODEFunction(f,jac_prototype=jac_proto), u0, (0.0,1.0), p_true)
20+
sol_true = solve(prob, Rodas4P(), saveat=0.1)
21+
22+
function loss(p)
23+
_prob = remake(prob, p=p)
24+
sol = solve(_prob, Rodas4P(autodiff=false), saveat=0.1)
25+
sum((sol .- sol_true).^2)
26+
end
27+
@test ForwardDiff.gradient(loss, [1.0])[1] 0.6662949361011025

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ if GROUP == "All" || GROUP == "Downstream"
2323
activate_downstream_env()
2424
@time @safetestset "ODEs" begin
2525
import OrdinaryDiffEq
26+
@time @safetestset "OrdinaryDiffEq Tridiagonal" begin include("ordinarydiffeq_tridiagonal_solve.jl") end
2627
include(joinpath(dirname(pathof(OrdinaryDiffEq)), "..", "test/interface/sparsediff_tests.jl"))
2728
end
2829
end

0 commit comments

Comments
 (0)