Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OrdinaryDiffEq"
uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "6.102.1"
authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -69,6 +69,8 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
Comment on lines +72 to +73
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should be in extras.

Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down Expand Up @@ -173,6 +175,8 @@ SciMLOperators = "1.8"
SciMLStructures = "1.7"
SimpleNonlinearSolve = "2.7"
SimpleUnPack = "1.1"
SparseConnectivityTracer = "1.1.1"
Sparspak = "0.3.14"
Static = "1.2"
StaticArrayInterface = "1.8"
StaticArrays = "1.9.14"
Expand Down
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Expand Down Expand Up @@ -42,6 +43,7 @@ MuladdMacro = "0.2"
LinearSolve = "3.26"
Polyester = "0.7"
LinearAlgebra = "1.10"
SparseArrays = "1.10"
OrdinaryDiffEqDifferentiation = "1.12.0"
SciMLBase = "2.99"
OrdinaryDiffEqCore = "1.29.0"
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester
isfirk, generic_solver_docstring
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
using SparseArrays: nonzeros
import LinearSolve
import FastBroadcast: @..
import OrdinaryDiffEqCore
import OrdinaryDiffEqCore: _ode_interpolant, _ode_interpolant!, has_stiff_interpolation
import FastPower: fastpower
using OrdinaryDiffEqDifferentiation: UJacobianWrapper, build_J_W, build_jac_config,
UDerivativeWrapper, calc_J!, dolinsolve, calc_J,
islinearfunction
islinearfunction, is_sparse
using OrdinaryDiffEqNonlinearSolve: du_alias_or_new, Convergence, FastConvergence, NLStatus,
VerySlowConvergence,
Divergence, get_new_W_γdt_cutoff
Expand Down
47 changes: 37 additions & 10 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,15 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits},
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw12)

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
W1 = similar(J, Complex{eltype(W1)})
recursivefill!(W1, false)
J, W1_temp = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
# For sparse matrices, preserve sparsity pattern for KLU compatibility
if is_sparse(J)
W1 = similar(J, Complex{eltype(W1_temp)})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line is the same between branches.

fill!(nonzeros(W1), false)
else
W1 = similar(J, Complex{eltype(W1_temp)})
recursivefill!(W1, false)
end

linprob = LinearProblem(W1, _vec(cubuff); u0 = _vec(dw12))
linsolve = init(
Expand Down Expand Up @@ -239,8 +245,14 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits},
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = similar(J, Complex{eltype(W1)})
recursivefill!(W2, false)
# For sparse matrices, preserve sparsity pattern for KLU compatibility
if is_sparse(J)
W2 = similar(J, Complex{eltype(W1)})
fill!(nonzeros(W2), false)
else
W2 = similar(J, Complex{eltype(W1)})
recursivefill!(W2, false)
end

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(
Expand Down Expand Up @@ -429,10 +441,18 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = similar(J, Complex{eltype(W1)})
W3 = similar(J, Complex{eltype(W1)})
recursivefill!(W2, false)
recursivefill!(W3, false)
# For sparse matrices, preserve sparsity pattern for KLU compatibility
if is_sparse(J)
W2 = similar(J, Complex{eltype(W1)})
W3 = similar(J, Complex{eltype(W1)})
fill!(nonzeros(W2), false)
fill!(nonzeros(W3), false)
else
W2 = similar(J, Complex{eltype(W1)})
W3 = similar(J, Complex{eltype(W1)})
recursivefill!(W2, false)
recursivefill!(W3, false)
end

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(
Expand Down Expand Up @@ -638,8 +658,15 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
end

# For sparse matrices, preserve sparsity pattern for KLU compatibility
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1:((max_stages - 1) ÷ 2)]
recursivefill!.(W2, false)
if is_sparse(J)
for W in W2
fill!(nonzeros(W), false)
end
else
recursivefill!.(W2, false)
end

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(
Expand Down
41 changes: 41 additions & 0 deletions test_issue_2892.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using OrdinaryDiffEq
using OrdinaryDiffEqFIRK
using ADTypes
using SparseConnectivityTracer
using LinearSolve
using Sparspak

function test_sparse(ode_solver)
function f(du, u, p, t)
du .= [u[1], u[2]]
end

u0 = [1.0, 2.0]
p = ()
du0 = similar(u0)
jac_prototype = float.(ADTypes.jacobian_sparsity(
(du, u) -> f(du, u, p, 0.0),
du0,
u0, TracerSparsityDetector()))

ode_fun = ODEFunction(f, jac_prototype=jac_prototype)
prob = ODEProblem(ode_fun, u0, (0, 10))
sol = solve(prob, ode_solver)
return sol
end

println("Testing AdaptiveRadau with LUFactorization...")
test_sparse(AdaptiveRadau(;linsolve=LUFactorization())) # Success
println("Success!")

println("Testing AdaptiveRadau with SparspakFactorization...")
test_sparse(AdaptiveRadau(;linsolve=SparspakFactorization())) # Success
println("Success!")

println("Testing QNDF with KLUFactorization...")
test_sparse(QNDF(;linsolve=KLUFactorization())) # Success
println("Success!")

println("Testing AdaptiveRadau with KLUFactorization...")
test_sparse(AdaptiveRadau(;linsolve=KLUFactorization())) # Error
println("Success!")
28 changes: 28 additions & 0 deletions test_radauiia5_klu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using OrdinaryDiffEq
using OrdinaryDiffEqFIRK
using ADTypes
using SparseConnectivityTracer
using LinearSolve

function test_sparse(ode_solver)
function f(du, u, p, t)
du .= [u[1], u[2]]
end

u0 = [1.0, 2.0]
p = ()
du0 = similar(u0)
jac_prototype = float.(ADTypes.jacobian_sparsity(
(du, u) -> f(du, u, p, 0.0),
du0,
u0, TracerSparsityDetector()))

ode_fun = ODEFunction(f, jac_prototype=jac_prototype)
prob = ODEProblem(ode_fun, u0, (0, 10))
sol = solve(prob, ode_solver)
return sol
end

println("Testing RadauIIA5 with KLUFactorization...")
test_sparse(RadauIIA5(;linsolve=KLUFactorization()))
println("Success!")
Loading