Skip to content

Commit 0473364

Browse files
author
oscarddssmith
committed
improve static
1 parent 0445429 commit 0473364

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/derivative_utils.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,7 @@ end
739739
if cache.W isa StaticWOperator
740740
integrator.stats.nw += 1
741741
J = calc_J(integrator, cache, next_step)
742+
@assert J isa StaticArray
742743
W = StaticWOperator(W_transform ? J - mass_matrix * inv(dtgamma) : dtgamma * J - mass_matrix)
743744
elseif cache.W isa WOperator
744745
integrator.stats.nw += 1
@@ -770,8 +771,6 @@ end
770771
end
771772
end
772773
end
773-
(W isa WOperator && unwrap_alg(integrator, true) isa NewtonAlgorithm) &&
774-
(W = update_coefficients!(W, uprev, p, t)) # we will call `update_coefficients!` in NLNewton
775774
is_compos && (integrator.eigen_est = isarray ? constvalue(opnorm(J, Inf)) :
776775
integrator.opts.internalnorm(J, t))
777776
return W
@@ -833,6 +832,7 @@ function update_W!(nlsolver::AbstractNLSolver,
833832
nothing
834833
end
835834

835+
import StaticArrays: StaticMatrix
836836
function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
837837
::Val{IIP}) where {IIP, uEltypeNoUnits, F}
838838
# TODO - make J, W AbstractSciMLOperators (lazily defined with scimlops functionality)
@@ -881,14 +881,20 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
881881
else
882882
deepcopy(f.jac_prototype)
883883
end
884-
__f = if IIP
885-
(du, u, p, t) -> _f(du, u, p, t)
884+
W = if J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
885+
StaticWOperator(J, false)
886+
elseif J isa StaticMatrix
887+
ArrayInterface.lu_instance(J)
886888
else
887-
(u, p, t) -> _f(u, p, t)
889+
__f = if IIP
890+
(du, u, p, t) -> _f(du, u, p, t)
891+
else
892+
(u, p, t) -> _f(u, p, t)
893+
end
894+
jacvec = JacVec(__f, copy(u), p, t;
895+
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
896+
WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
888897
end
889-
jacvec = JacVec(__f, copy(u), p, t;
890-
autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag())
891-
W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)
892898
else
893899
J = if !IIP && DiffEqBase.has_jac(f)
894900
f.jac(uprev, p, t)
@@ -901,16 +907,13 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
901907
J
902908
elseif IIP
903909
similar(J)
910+
elseif J isa StaticMatrix && alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
911+
StaticWOperator(J, false)
904912
else
905-
len = StaticArrayInterface.known_length(typeof(J))
906-
if len !== nothing &&
907-
alg isa OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
908-
StaticWOperator(J, false)
909-
else
910-
ArrayInterface.lu_instance(J)
911-
end
913+
ArrayInterface.lu_instance(J)
912914
end
913915
end
916+
@show W
914917
return J, W
915918
end
916919

0 commit comments

Comments
 (0)