Skip to content

Commit cb105fc

Browse files
author
Avik Pal
committed
Fix Jacobian Construction
1 parent a1877cd commit cb105fc

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
44
@eval Base.Experimental.@max_methods 1
55
end
66

7-
using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
7+
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
88
import ForwardDiff
99

1010
import ADTypes: AbstractFiniteDifferencesMode

src/jacobian.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
8080
if has_analytic_jac
8181
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
8282
else
83-
f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype
83+
if f.jac_prototype === nothing
84+
__safe_init_jacobian(jac_cache)
85+
else
86+
f.jac_prototype
87+
end
8488
end
8589
end
8690

@@ -98,6 +102,26 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
98102
return uf, linsolve, J, fu, jac_cache, du
99103
end
100104

105+
@generated function __getfield(c::T, ::Val{S}) where {T, S}
106+
hasfield(T, S) && return :(c.$(S))
107+
return :(nothing)
108+
end
109+
110+
function __safe_init_jacobian(c::SparseDiffTools.AbstractMaybeSparseJacobianCache)
111+
T = promote_type(eltype(c.fx), eltype(c.x))
112+
return __safe_init_jacobian(__getfield(c, Val(:jac_prototype)), T, c.fx, c.x)
113+
end
114+
function __safe_init_jacobian(::Nothing, ::Type{T}, fx, x) where {T}
115+
return similar(fx, T, length(fx), length(x))
116+
end
117+
function __safe_init_jacobian(J::SparseMatrixCSC, ::Type{T}, fx, x) where {T}
118+
@assert size(J, 1) == length(fx) && size(J, 2) == length(x)
119+
return T.(J)
120+
end
121+
function __safe_init_jacobian(J, ::Type{T}, fx, x) where {T}
122+
return similar(fx, T, length(fx), length(x)) # This is not safe for sparse jacobians
123+
end
124+
101125
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
102126
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
103127
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()

0 commit comments

Comments
 (0)