Skip to content

Commit c8409a9

Browse files
Merge pull request #1288 from ChrisRackauckas-Claude/fix-sparse-jac-functionwrapper
Wrap Jacobian with both dense and sparse FunctionWrapper signatures
2 parents 92bfd6b + dea4704 commit c8409a9

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

src/solve.jl

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -810,12 +810,34 @@ function promote_f(
810810
if f.tgrad !== nothing && !(f.tgrad isa FunctionWrappersWrappers.FunctionWrappersWrapper)
811811
f = @set f.tgrad = wrapfun_jac_iip(f.tgrad, (u0, u0, p, t))
812812
end
813-
# Wrap the Jacobian if present, so its type is also erased
813+
# Wrap the Jacobian if present, so its type is also erased.
814+
# Include both dense and sparse matrix signatures when the function
815+
# has a sparsity pattern, since the solver may use either depending on
816+
# the autodiff configuration (AutoSparse creates sparse J from sparsity).
814817
if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper)
815-
n = length(u0)
816-
J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) :
817-
zeros(uElType, n, n)
818-
f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t))
818+
if f.jac_prototype !== nothing
819+
J_T = Base.promote_op(similar, typeof(f.jac_prototype), Type{uElType})
820+
sig = Tuple{J_T, typeof(u0), typeof(p), typeof(t)}
821+
f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper(
822+
Void(f.jac), (sig,), (Nothing,))
823+
elseif isdefined(f, :sparsity) && f.sparsity isa AbstractMatrix &&
824+
!(f.sparsity isa Matrix)
825+
# The sparsity pattern is a non-dense matrix (e.g. SparseMatrixCSC).
826+
# The solver may call the Jacobian with either a dense or sparse matrix
827+
# depending on the autodiff config, so wrap for both signatures.
828+
dense_sig = Tuple{Matrix{uElType}, typeof(u0), typeof(p), typeof(t)}
829+
sparse_J_T = Base.promote_op(similar, typeof(f.sparsity), Type{uElType})
830+
sparse_sig = Tuple{sparse_J_T, typeof(u0), typeof(p), typeof(t)}
831+
f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper(
832+
Void(f.jac),
833+
(dense_sig, sparse_sig),
834+
(Nothing, Nothing)
835+
)
836+
else
837+
sig = Tuple{Matrix{uElType}, typeof(u0), typeof(p), typeof(t)}
838+
f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper(
839+
Void(f.jac), (sig,), (Nothing,))
840+
end
819841
end
820842
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t), Val(CS)))
821843
else

0 commit comments

Comments
 (0)