Skip to content

Commit 7aae96e

Browse files
Merge pull request SciML#2291 from oscardssmith/os/fix-OOP-WOperator
fix out of place `WOperator`
2 parents 17bd407 + 9c667c4 commit 7aae96e

File tree

4 files changed

+87
-88
lines changed

4 files changed

+87
-88
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
101101
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
102102
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
103103
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
104+
ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968"
104105
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
105106
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
106107
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -114,4 +115,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
114115
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
115116

116117
[targets]
117-
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"]
118+
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"]

src/OrdinaryDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ using NonlinearSolve
6666

6767
# Required by temporary fix in not in-place methods with 12+ broadcasts
6868
# `MVector` is used by Nordsieck forms
69-
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA
69+
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix
7070

7171
# Integrator Interface
7272
import DiffEqBase: resize!, deleteat!, addat!, full_cache, user_cache, u_cache, du_cache,

src/derivative_utils.jl

Lines changed: 44 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,12 @@ function calc_J(integrator, cache, next_step::Bool = false)
9090
J = jacobian(uf, uprev, integrator)
9191
end
9292

93-
integrator.stats.njacs += 1
94-
9593
if alg isa CompositeAlgorithm
9694
integrator.eigen_est = constvalue(opnorm(J, Inf))
9795
end
9896
end
9997

98+
integrator.stats.njacs += 1
10099
J
101100
end
102101

@@ -144,12 +143,11 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
144143
end
145144
end
146145

147-
integrator.stats.njacs += 1
148-
149146
if alg isa CompositeAlgorithm
150147
integrator.eigen_est = constvalue(opnorm(J, Inf))
151148
end
152149

150+
integrator.stats.njacs += 1
153151
return nothing
154152
end
155153

@@ -604,21 +602,21 @@ function jacobian2W!(W::Matrix, mass_matrix, dtgamma::Number, J::Matrix,
604602
return nothing
605603
end
606604

607-
function jacobian2W(mass_matrix::MT, dtgamma::Number, J::AbstractMatrix,
608-
W_transform::Bool)::Nothing where {MT}
605+
function jacobian2W(mass_matrix, dtgamma::Number, J::AbstractMatrix,
606+
W_transform::Bool)
609607
# check size and dimension
610608
mass_matrix isa UniformScaling ||
611609
@boundscheck axes(mass_matrix) == axes(J) || _throwJMerror(J, mass_matrix)
612610
@inbounds if W_transform
613611
invdtgamma = inv(dtgamma)
614-
if MT <: UniformScaling
612+
if mass_matrix isa UniformScaling
615613
λ = -mass_matrix.λ
616614
W = J +* invdtgamma) * I
617615
else
618616
W = muladd(-mass_matrix, invdtgamma, J)
619617
end
620618
else
621-
if MT <: UniformScaling
619+
if mass_matrix isa UniformScaling
622620
λ = -mass_matrix.λ
623621
W = dtgamma * J + λ * I
624622
else
@@ -738,67 +736,33 @@ end
738736
islin, isode = islinearfunction(integrator)
739737
!isdae && update_coefficients!(mass_matrix, uprev, p, t)
740738

741-
if cache.W isa WOperator
742-
W = cache.W
743-
if isnewton(nlsolver)
744-
# we will call `update_coefficients!` for u/p/t in NLNewton
745-
update_coefficients!(W; transform = W_transform, dtgamma)
739+
if cache.W isa StaticWOperator
740+
integrator.stats.nw += 1
741+
J = calc_J(integrator, cache, next_step)
742+
W = StaticWOperator(W_transform ? J - mass_matrix * inv(dtgamma) : dtgamma * J - mass_matrix)
743+
elseif cache.W isa WOperator
744+
integrator.stats.nw += 1
745+
J = if islin
746+
isode ? f.f : f.f1.f
746747
else
747-
update_coefficients!(W, uprev, p, t; transform = W_transform, dtgamma)
748+
calc_J(integrator, cache, next_step)
748749
end
749-
if W.J !== nothing && !(W.J isa AbstractSciMLOperator)
750-
islin, isode = islinearfunction(integrator)
751-
J = islin ? (isode ? f.f : f.f1.f) : calc_J(integrator, cache, next_step)
752-
!isdae &&
753-
jacobian2W!(W._concrete_form, mass_matrix, dtgamma, J, W_transform)
754-
end
755-
elseif cache.W isa AbstractSciMLOperator && !(cache.W isa StaticWOperator)
756-
J = update_coefficients(cache.J, uprev, p, t)
750+
W = WOperator{false}(mass_matrix, dtgamma, J, uprev, cache.W.jacvec; transform = W_transform)
751+
elseif cache.W isa AbstractSciMLOperator
757752
W = update_coefficients(cache.W, uprev, p, t; dtgamma, transform = W_transform)
758-
elseif islin
759-
J = isode ? f.f : f.f1.f # unwrap the Jacobian accordingly
760-
W = WOperator{false}(mass_matrix, dtgamma, J, uprev; transform = W_transform)
761-
elseif DiffEqBase.has_jac(f)
762-
J = f.jac(uprev, p, t)
763-
if J isa StaticArray &&
764-
integrator.alg isa
765-
Union{
766-
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2,
767-
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
768-
W = W_transform ? J - mass_matrix * inv(dtgamma) :
769-
dtgamma * J - mass_matrix
770-
else
771-
if !isa(J, AbstractSciMLOperator) && (!isnewton(nlsolver) ||
772-
nlsolver.cache.W.J isa AbstractSciMLOperator)
773-
J = MatrixOperator(J)
774-
end
775-
W = WOperator{false}(mass_matrix, dtgamma, J, uprev, cache.W.jacvec;
776-
transform = W_transform)
777-
end
778-
integrator.stats.nw += 1
779753
else
780754
integrator.stats.nw += 1
781-
J = calc_J(integrator, cache, next_step)
755+
J = islin ? isode ? f.f : f.f1.f : calc_J(integrator, cache, next_step)
782756
if isdae
783757
W = J
784758
else
785-
W_full = W_transform ? J - mass_matrix * inv(dtgamma) :
759+
W = W_transform ? J - mass_matrix * inv(dtgamma) :
786760
dtgamma * J - mass_matrix
787-
len = StaticArrayInterface.known_length(typeof(W_full))
788-
W = if W_full isa Number
789-
W_full
790-
elseif len !== nothing &&
791-
integrator.alg isa
792-
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
793-
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
794-
StaticWOperator(W_full)
795-
else
796-
DiffEqBase.default_factorize(W_full)
761+
if !isa(W, Number)
762+
W = DiffEqBase.default_factorize(W)
797763
end
798764
end
799765
end
800-
(W isa WOperator && unwrap_alg(integrator, true) isa NewtonAlgorithm) &&
801-
(W = update_coefficients!(W, uprev, p, t)) # we will call `update_coefficients!` in NLNewton
802766
is_compos && (integrator.eigen_est = isarray ? constvalue(opnorm(J, Inf)) :
803767
integrator.opts.internalnorm(J, t))
804768
return W
@@ -876,10 +840,11 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
876840
elseif f.jac_prototype isa AbstractSciMLOperator
877841
W = WOperator{IIP}(f, u, dt)
878842
J = W.J
843+
elseif islin
844+
J = isode ? f.f : f.f1.f # unwrap the Jacobian accordingly
845+
W = WOperator{IIP}(f.mass_matrix, dt, J, u)
879846
elseif IIP && f.jac_prototype !== nothing && concrete_jac(alg) === nothing &&
880-
(alg.linsolve === nothing ||
881-
alg.linsolve !== nothing &&
882-
LinearSolve.needs_concrete_A(alg.linsolve))
847+
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))
883848

884849
# If factorization, then just use the jac_prototype
885850
J = similar(f.jac_prototype)
@@ -896,7 +861,6 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
896861
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
897862
J = jacvec
898863
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
899-
900864
elseif alg.linsolve !== nothing && !LinearSolve.needs_concrete_A(alg.linsolve) ||
901865
concrete_jac(alg) !== nothing && concrete_jac(alg)
902866
# The linear solver does not need a concrete Jacobian, but the user has
@@ -908,42 +872,36 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
908872
else
909873
deepcopy(f.jac_prototype)
910874
end
911-
__f = if IIP
912-
(du, u, p, t) -> _f(du, u, p, t)
875+
W = if J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
876+
StaticWOperator(J, false)
877+
elseif J isa StaticMatrix
878+
ArrayInterface.lu_instance(J)
913879
else
914-
(u, p, t) -> _f(u, p, t)
915-
end
916-
jacvec = JacVec(__f, copy(u), p, t;
917-
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
918-
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
919-
920-
elseif islin || (!IIP && DiffEqBase.has_jac(f))
921-
J = islin ? (isode ? f.f : f.f1.f) : f.jac(uprev, p, t) # unwrap the Jacobian accordingly
922-
if !isa(J, AbstractSciMLOperator)
923-
J = MatrixOperator(J)
880+
__f = if IIP
881+
(du, u, p, t) -> _f(du, u, p, t)
882+
else
883+
(u, p, t) -> _f(u, p, t)
884+
end
885+
jacvec = JacVec(__f, copy(u), p, t;
886+
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
887+
WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
924888
end
925-
W = WOperator{IIP}(f.mass_matrix, dt, J, u)
926889
else
927-
J = if f.jac_prototype === nothing
890+
J = if !IIP && DiffEqBase.has_jac(f)
891+
f.jac(uprev, p, t)
892+
elseif f.jac_prototype === nothing
928893
ArrayInterface.undefmatrix(u)
929894
else
930895
deepcopy(f.jac_prototype)
931896
end
932-
isdae = alg isa DAEAlgorithm
933-
W = if isdae
897+
W = if alg isa DAEAlgorithm
934898
J
935899
elseif IIP
936900
similar(J)
901+
elseif J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
902+
StaticWOperator(J, false)
937903
else
938-
len = StaticArrayInterface.known_length(typeof(J))
939-
if len !== nothing &&
940-
alg isa
941-
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
942-
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
943-
StaticWOperator(J, false)
944-
else
945-
ArrayInterface.lu_instance(J)
946-
end
904+
ArrayInterface.lu_instance(J)
947905
end
948906
end
949907
return J, W

test/interface/linear_solver_test.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,43 @@ end
157157
atol = 1e-1, rtol = 1e-1)
158158
@test isapprox(exp.(p), g_helper(p; alg = KenCarp47(linsolve = KrylovJL_GMRES()));
159159
atol = 1e-1, rtol = 1e-1)
160+
161+
using OrdinaryDiffEq, StaticArrays, LinearSolve, ParameterizedFunctions
162+
163+
hires = @ode_def Hires begin
164+
dy1 = -1.71*y1 + 0.43*y2 + 8.32*y3 + 0.0007
165+
dy2 = 1.71*y1 - 8.75*y2
166+
dy3 = -10.03*y3 + 0.43*y4 + 0.035*y5
167+
dy4 = 8.32*y2 + 1.71*y3 - 1.12*y4
168+
dy5 = -1.745*y5 + 0.43*y6 + 0.43*y7
169+
dy6 = -280.0*y6*y8 + 0.69*y4 + 1.71*y5 - 0.43*y6 + 0.69*y7
170+
dy7 = 280.0*y6*y8 - 1.81*y7
171+
dy8 = -280.0*y6*y8 + 1.81*y7
172+
end
173+
174+
u0 = zeros(8)
175+
u0[1] = 1
176+
u0[8] = 0.0057
177+
178+
probiip = ODEProblem{true}(hires, u0, (0.0,10.0))
179+
proboop = ODEProblem{false}(hires, u0, (0.0,10.0))
180+
probstatic = ODEProblem{false}(hires, SVector{8}(u0), (0.0,10.0))
181+
probs = (;probiip, proboop, probstatic)
182+
qndf = QNDF()
183+
krylov_qndf = QNDF(linsolve=KrylovJL_GMRES())
184+
fbdf = FBDF()
185+
krylov_fbdf = FBDF(linsolve=KrylovJL_GMRES())
186+
rodas = Rodas5P()
187+
krylov_rodas = Rodas5P(linsolve=KrylovJL_GMRES())
188+
solvers = (;qndf, krylov_qndf, rodas, krylov_rodas, fbdf, krylov_fbdf, )
189+
190+
refsol = solve(probiip, FBDF(), abstol=1e-12, reltol=1e-12)
191+
@testset "Hires calc_W tests" begin
192+
@testset "$probname" for (probname, prob) in pairs(probs)
193+
@testset "$solname" for (solname, solver) in pairs(solvers)
194+
sol = solve(prob, solver, abstol=1e-12, reltol=1e-12, maxiters=2e4)
195+
@test sol.retcode == ReturnCode.Success
196+
@test isapprox(sol.u[end], refsol.u[end], rtol=1e-8, atol=1e-10)
197+
end
198+
end
199+
end

0 commit comments

Comments
 (0)