@@ -810,12 +810,34 @@ function promote_f(
810810 if f. tgrad != = nothing && ! (f. tgrad isa FunctionWrappersWrappers. FunctionWrappersWrapper)
811811 f = @set f. tgrad = wrapfun_jac_iip (f. tgrad, (u0, u0, p, t))
812812 end
813- # Wrap the Jacobian if present, so its type is also erased
813+ # Wrap the Jacobian if present, so its type is also erased.
814+ # Include both dense and sparse matrix signatures when the function
815+ # has a sparsity pattern, since the solver may use either depending on
816+ # the autodiff configuration (AutoSparse creates sparse J from sparsity).
814817 if f. jac != = nothing && ! (f. jac isa FunctionWrappersWrappers. FunctionWrappersWrapper)
815- n = length (u0)
816- J_proto = f. jac_prototype != = nothing ? similar (f. jac_prototype, uElType) :
817- zeros (uElType, n, n)
818- f = @set f. jac = wrapfun_jac_iip (f. jac, (J_proto, u0, p, t))
818+ if f. jac_prototype != = nothing
819+ J_T = Base. promote_op (similar, typeof (f. jac_prototype), Type{uElType})
820+ sig = Tuple{J_T, typeof (u0), typeof (p), typeof (t)}
821+ f = @set f. jac = FunctionWrappersWrappers. FunctionWrappersWrapper (
822+ Void (f. jac), (sig,), (Nothing,))
823+ elseif isdefined (f, :sparsity ) && f. sparsity isa AbstractMatrix &&
824+ ! (f. sparsity isa Matrix)
825+ # The sparsity pattern is a non-dense matrix (e.g. SparseMatrixCSC).
826+ # The solver may call the Jacobian with either a dense or sparse matrix
827+ # depending on the autodiff config, so wrap for both signatures.
828+ dense_sig = Tuple{Matrix{uElType}, typeof (u0), typeof (p), typeof (t)}
829+ sparse_J_T = Base. promote_op (similar, typeof (f. sparsity), Type{uElType})
830+ sparse_sig = Tuple{sparse_J_T, typeof (u0), typeof (p), typeof (t)}
831+ f = @set f. jac = FunctionWrappersWrappers. FunctionWrappersWrapper (
832+ Void (f. jac),
833+ (dense_sig, sparse_sig),
834+ (Nothing, Nothing)
835+ )
836+ else
837+ sig = Tuple{Matrix{uElType}, typeof (u0), typeof (p), typeof (t)}
838+ f = @set f. jac = FunctionWrappersWrappers. FunctionWrappersWrapper (
839+ Void (f. jac), (sig,), (Nothing,))
840+ end
819841 end
820842 return unwrapped_f (f, wrapfun_iip (f. f, (u0, u0, p, t), Val (CS)))
821843 else
0 commit comments