Skip to content

Performance regression using SciMLOperators with OrdinaryDiffEq.jl #315

@utkarsh530

Description

@utkarsh530

If we are using SciMLOperator as a prototype to W or Jacobian, application of the operator triggers GC frequently, regressing performance:

MWE

using LinearAlgebra, SciMLOperators, OrdinaryDiffEq, LinearSolve

function chain!(du,u,p,t)
    a = 10.0
    du[1] = -a*u[1]
    for i in 2:length(u)
        du[i] = a*(u[i-1]-u[i])
    end
end
function jac_chain!(J,u,p,t)
    a = 10.0; fill!(J,0.0)
    for i in 2:length(u)
        J[i,i] = -a
        if i>1; J[i,i-1] = a; end
    end
end

N=1000
u0=ones(N)
tspan = (0.0, 2.0)

J = zeros(N, N)

jac_chain!(J, u0, nothing, 0.0)


I_N = Diagonal(ones(N))


J1_op = MatrixOperator(UpperTriangular(J - Diagonal(J)))
J2_op = MatrixOperator(LowerTriangular(J))
J_op = J1_op + J2_op

J_op = cache_operator(J_op, zeros(N^2))


Ju = UpperTriangular(J - Diagonal(J))

W1_op = MatrixOperator(
    I_N - Ju;
    update_func! = (M, u, p, t;
    dtgamma = 1.0
    ) ->(@. M = I_N - dtgamma * Ju),
    accepted_kwargs = Val((:dtgamma,)),
)

Jl = LowerTriangular(J)

W2_op = MatrixOperator(
    I_N - Jl;
    update_func! = (M, u, p, t; dtgamma = 1.0) ->
        (@. M = I_N - dtgamma * Jl),
    accepted_kwargs = Val((:dtgamma,)),
)

transform_op = ScalarOperator(
    0.0;
    update_func = (old_op, u, p, t; dtgamma = 1.0) ->
        true ? inv(dtgamma) : one(dtgamma),
    accepted_kwargs = Val((:dtgamma, )),
)

W_prototype = -(W1_op * W2_op) * transform_op

W_prototype = cache_operator(W_prototype, zeros(N))

func = ODEFunction{true, SciMLBase.FullSpecialize}(chain!; jac_prototype = MatrixOperator(J), jac = jac_chain!, W_prototype, sparsity = convert(AbstractMatrix, J_op))
prob = ODEProblem(func, u0, tspan)

@time sol = solve(prob, KenCarp4());

@profview_allocs solve(prob, KenCarp4());
Image

The culprit in this case triggering GC (linsolve is a Krylov method) is this piece of code:

function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
@assert iscached(L) """cache needs to be set up for operator of type
$L. Set up cache by calling `cache_operator(L, v)`"""
vecs = (w, L.cache[1:(end - 1)]..., v)
for i in reverse(1:length(L.ops))
mul!(vecs[i], L.ops[i], vecs[i + 1])
end
w
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions