-
-
Notifications
You must be signed in to change notification settings - Fork 16
Closed
Description
When a SciMLOperator is used as the prototype for W or for the Jacobian, applying the operator frequently triggers GC, which regresses performance:
MWE
using LinearAlgebra, SciMLOperators, OrdinaryDiffEq, LinearSolve
"""
    chain!(du, u, p, t)

In-place right-hand side of a linear chain system with rate `10.0`.

The first component decays on its own (`du[1] = -rate * u[1]`); every later
component is driven by the difference to its predecessor
(`du[k] = rate * (u[k - 1] - u[k])`). `p` and `t` are accepted for the ODE
function signature but are not read. Returns `nothing`.
"""
function chain!(du, u, p, t)
    rate = 10.0
    du[1] = -rate * u[1]
    k = 2
    while k <= length(u)
        du[k] = rate * (u[k - 1] - u[k])
        k += 1
    end
    return nothing
end
"""
    jac_chain!(J, u, p, t)

Fill `J` in place with the Jacobian of `chain!` evaluated at `u`.

`J` is first zeroed, then every diagonal entry is set to `-a` and every
sub-diagonal entry `J[i, i - 1]` to `a`, with `a = 10.0`. `p` and `t` are
accepted for the Jacobian function signature but are not read. Returns
`nothing`.
"""
function jac_chain!(J, u, p, t)
    a = 10.0
    fill!(J, 0.0)
    # Start at the first index: du[1] = -a * u[1] contributes J[1, 1] = -a,
    # which a loop starting at 2 would leave at zero.
    for i in eachindex(u)
        J[i, i] = -a
        # The first row has no predecessor, hence no sub-diagonal entry.
        if i > firstindex(u)
            J[i, i - 1] = a
        end
    end
    return nothing
end
# --- Problem setup: system size, initial state and integration interval ---
N=1000
u0=ones(N)
tspan = (0.0, 2.0)
# Dense N×N Jacobian buffer, filled once by jac_chain! at u0 and t = 0.
J = zeros(N, N)
jac_chain!(J, u0, nothing, 0.0)
# Identity stored as a Diagonal; used below to build the W factors.
I_N = Diagonal(ones(N))
# --- Jacobian as a sum of two operators ---
# J1_op wraps the upper triangle of J with its diagonal zeroed out,
# J2_op wraps the lower triangle of J (diagonal included).
J1_op = MatrixOperator(UpperTriangular(J - Diagonal(J)))
J2_op = MatrixOperator(LowerTriangular(J))
J_op = J1_op + J2_op
# NOTE(review): the cache prototype here has length N^2, whereas W_prototype
# below is cached with a length-N vector — confirm N^2 is intended.
J_op = cache_operator(J_op, zeros(N^2))
# --- W operator built as a product of two triangular factors ---
# First factor: starts as I_N - Ju; its update function overwrites the stored
# matrix with I_N - dtgamma * Ju, where dtgamma arrives as a keyword argument.
# NOTE(review): I_N and Ju are non-const globals read inside the closure —
# confirm this is not contributing to the allocations being measured.
Ju = UpperTriangular(J - Diagonal(J))
W1_op = MatrixOperator(
I_N - Ju;
update_func! = (M, u, p, t;
dtgamma = 1.0
) ->(@. M = I_N - dtgamma * Ju),
accepted_kwargs = Val((:dtgamma,)),
)
# Second factor: same construction using the lower triangle Jl.
Jl = LowerTriangular(J)
W2_op = MatrixOperator(
I_N - Jl;
update_func! = (M, u, p, t; dtgamma = 1.0) ->
(@. M = I_N - dtgamma * Jl),
accepted_kwargs = Val((:dtgamma,)),
)
# Scalar factor: starts at 0.0 and is updated to inv(dtgamma).
# The ternary condition is the literal `true`, so the one(dtgamma) branch is
# never taken.
transform_op = ScalarOperator(
0.0;
update_func = (old_op, u, p, t; dtgamma = 1.0) ->
true ? inv(dtgamma) : one(dtgamma),
accepted_kwargs = Val((:dtgamma, )),
)
# W_prototype is the negated product W1_op * W2_op, scaled by transform_op,
# then cached against a length-N vector.
W_prototype = -(W1_op * W2_op) * transform_op
W_prototype = cache_operator(W_prototype, zeros(N))
# --- ODE problem and solve ---
# Fully specialized in-place ODEFunction: J wrapped as a MatrixOperator is the
# jac_prototype, jac_chain! is the analytic Jacobian, W_prototype is the
# operator built above, and J_op converted to an AbstractMatrix is the
# sparsity pattern.
func = ODEFunction{true, SciMLBase.FullSpecialize}(chain!; jac_prototype = MatrixOperator(J), jac = jac_chain!, W_prototype, sparsity = convert(AbstractMatrix, J_op))
prob = ODEProblem(func, u0, tspan)
# Time one solve with KenCarp4, then collect an allocation profile of a
# second solve.
@time sol = solve(prob, KenCarp4());
# NOTE(review): @profview_allocs is not provided by any package named on the
# `using` line of this MWE — presumably it comes from the editor/profiling
# environment; confirm before running as a plain script.
@profview_allocs solve(prob, KenCarp4());
The culprit that triggers GC in this case (linsolve is a Krylov method) is this piece of code:
SciMLOperators.jl/src/basic.jl
Lines 890 to 899 in 572d354
| function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) | |
| @assert iscached(L) """cache needs to be set up for operator of type | |
| $L. Set up cache by calling `cache_operator(L, v)`""" | |
| vecs = (w, L.cache[1:(end - 1)]..., v) | |
| for i in reverse(1:length(L.ops)) | |
| mul!(vecs[i], L.ops[i], vecs[i + 1]) | |
| end | |
| w | |
| end |
Metadata
Metadata
Assignees
Labels
No labels