diff --git a/src/basic.jl b/src/basic.jl index e0f5291b..05868889 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -887,15 +887,35 @@ function cache_internals(L::ComposedOperator, v::AbstractVecOrMat) @reset L.ops = ops end -function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) - @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, v)`""" +@generated function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) + N = length(L.parameters[2].parameters) # Number of operators + + # Generate the mul! calls in reverse order + # vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v) + # For i in reverse(1:N): + # mul!(vecs[i], L.ops[i], vecs[i+1]) + + exprs = [] + for i in N:-1:1 + if i == N + # Last operator: mul!(L.cache[N-1], L.ops[N], v) + push!(exprs, :(mul!(L.cache[$(N - 1)], L.ops[$i], v))) + elseif i == 1 + # First operator: mul!(w, L.ops[1], L.cache[1]) + push!(exprs, :(mul!(w, L.ops[$i], L.cache[1]))) + else + # Middle operators: mul!(L.cache[i-1], L.ops[i], L.cache[i]) + push!(exprs, :(mul!(L.cache[$(i - 1)], L.ops[$i], L.cache[$i]))) + end + end - vecs = (w, L.cache[1:(end - 1)]..., v) - for i in reverse(1:length(L.ops)) - mul!(vecs[i], L.ops[i], vecs[i + 1]) + quote + @assert iscached(L) """cache needs to be set up for operator of type + $L. Set up cache by calling `cache_operator(L, v)`""" + + $(exprs...) + w end - w end function LinearAlgebra.mul!(w::AbstractVecOrMat, @@ -914,15 +934,36 @@ function LinearAlgebra.mul!(w::AbstractVecOrMat, axpy!(β, cache, w) end -function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) - @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, v)`.""" +@generated function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) + N = length(L.parameters[2].parameters) # Number of operators - vecs = (v, reverse(L.cache[1:(end - 1)])..., w) - for i in 1:length(L.ops) - ldiv!(vecs[i + 1], L.ops[i], vecs[i]) + # Generate the ldiv! calls in forward order + # vecs conceptually is (v, reverse(L.cache[1:(N-1)])..., w) + # = (v, L.cache[N-1], L.cache[N-2], ..., L.cache[1], w) + # For i in 1:N: + # ldiv!(vecs[i+1], L.ops[i], vecs[i]) + + exprs = [] + for i in 1:N + if i == 1 + # First operator: ldiv!(L.cache[N-1], L.ops[1], v) + push!(exprs, :(ldiv!(L.cache[$(N - 1)], L.ops[$i], v))) + elseif i == N + # Last operator: ldiv!(w, L.ops[N], L.cache[1]) + push!(exprs, :(ldiv!(w, L.ops[$i], L.cache[1]))) + else + # Middle operators: ldiv!(L.cache[N-i], L.ops[i], L.cache[N-i+1]) + push!(exprs, :(ldiv!(L.cache[$(N - i)], L.ops[$i], L.cache[$(N - i + 1)]))) + end + end + + quote + @assert iscached(L) """cache needs to be set up for operator of type + $L. Set up cache by calling `cache_operator(L, v)`.""" + + $(exprs...) + w end - w end function LinearAlgebra.ldiv!(L::ComposedOperator, v::AbstractVecOrMat) @@ -943,15 +984,36 @@ function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) end # In-place: w is destination, v is action vector, u is update vector -function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) - update_coefficients!(L, u, p, t; kwargs...) - @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first." +@generated function (L::ComposedOperator)( + w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + N = length(L.parameters[2].parameters) # Number of operators + + # Generate the operator call expressions in reverse order + # vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v) + # For i in reverse(1:N): + # L.ops[i](vecs[i], vecs[i+1], u, p, t; kwargs...) + + exprs = [] + for i in N:-1:1 + if i == N + # Last operator: L.ops[N](L.cache[N-1], v, u, p, t; kwargs...) + push!(exprs, :(L.ops[$i](L.cache[$(N - 1)], v, u, p, t; kwargs...))) + elseif i == 1 + # First operator: L.ops[1](w, L.cache[1], u, p, t; kwargs...) + push!(exprs, :(L.ops[$i](w, L.cache[1], u, p, t; kwargs...))) + else + # Middle operators: L.ops[i](L.cache[i-1], L.cache[i], u, p, t; kwargs...) + push!(exprs, :(L.ops[$i](L.cache[$(i - 1)], L.cache[$i], u, p, t; kwargs...))) + end + end - vecs = (w, L.cache[1:(end - 1)]..., v) - for i in reverse(1:length(L.ops)) - L.ops[i](vecs[i], vecs[i + 1], u, p, t; kwargs...) + quote + update_coefficients!(L, u, p, t; kwargs...) + @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first." + + $(exprs...) + w end - w end # In-place with scaling: w = α*(L*v) + β*w diff --git a/test/downstream/alloccheck.jl b/test/downstream/alloccheck.jl index acab8f0c..9e4e7662 100644 --- a/test/downstream/alloccheck.jl +++ b/test/downstream/alloccheck.jl @@ -1,8 +1,10 @@ -using SciMLOperators, Random, SparseArrays, Test +using SciMLOperators, Random, SparseArrays, Test, LinearAlgebra using SciMLOperators: IdentityOperator, NullOperator, ScaledOperator, - AddedOperator + AddedOperator, + ComposedOperator, + cache_operator function apply_op!(H, w, v, u, p, t) H(w, v, u, p, t) @@ -64,4 +66,105 @@ test_apply_noalloc(H, w, v, u, p, t) = @test (@allocations apply_op!(H, w, v, u, test_apply_noalloc(H_sparse, w, v, u, p, t) test_apply_noalloc(H_dense, w, v, u, p, t) end + + # Test ComposedOperator allocations (PR #316) + # Before the fix, tuple splatting caused many allocations. + # After the fix, we should have minimal allocations (Julia 1.11 has 1, earlier versions have 0). + @testset "ComposedOperator minimal allocations" begin + N = 100 + + # Create operators for composition + A1 = MatrixOperator(rand(N, N)) + A2 = MatrixOperator(rand(N, N)) + A3 = MatrixOperator(rand(N, N)) + + # Create ComposedOperator + L = A1 * A2 * A3 + + # Set up cache + v = rand(N) + w = similar(v) + L = cache_operator(L, v) + + u = rand(N) + p = nothing + t = 0.0 + + # Warm up + mul!(w, L, v) + L(w, v, u, p, t) + + # Test mul! - should have minimal allocations + # Julia 1.11 has a known minor allocation issue (1 allocation) + # Earlier versions should have 0 allocations + allocs_mul = @allocations mul!(w, L, v) + @test allocs_mul <= 1 + + # Test operator call - should have minimal allocations + allocs_call = @allocations L(w, v, u, p, t) + @test allocs_call <= 1 + + # Test with matrices + K = 5 + V = rand(N, K) + W = similar(V) + L_mat = cache_operator(A1 * A2 * A3, V) + + # Warm up + mul!(W, L_mat, V) + L_mat(W, V, u, p, t) + + # Test with matrices - should have minimal allocations + allocs_mul_mat = @allocations mul!(W, L_mat, V) + @test allocs_mul_mat <= 1 + + allocs_call_mat = @allocations L_mat(W, V, u, p, t) + @test allocs_call_mat <= 1 + end + + # Test accepted_kwargs allocations (PR #313) + # With Val(tuple), kwarg filtering should be compile-time with minimal allocations + @testset "accepted_kwargs with Val" begin + N = 50 + + # Create a MatrixOperator with accepted_kwargs using Val for compile-time filtering + J = rand(N, N) + + update_func! = (M, u, p, t; dtgamma = 1.0) -> begin + M .= dtgamma .* J + nothing + end + + op = MatrixOperator( + copy(J); + update_func! = update_func!, + accepted_kwargs = Val((:dtgamma,)) # Use Val for compile-time filtering + ) + + u = rand(N) + p = nothing + t = 0.0 + + # Warm up + update_coefficients!(op, u, p, t; dtgamma = 0.5) + + # Test that update_coefficients! with accepted_kwargs has minimal allocations + # The Val approach significantly reduces allocations compared to plain tuples + allocs_update = @allocations update_coefficients!(op, u, p, t; dtgamma = 0.5) + @test allocs_update <= 6 # Some allocations may occur due to Julia version/kwarg handling + + # Test with different dtgamma values - should have similar behavior + allocs_update2 = @allocations update_coefficients!(op, u, p, t; dtgamma = 1.0) + @test allocs_update2 <= 6 + + allocs_update3 = @allocations update_coefficients!(op, u, p, t; dtgamma = 2.0) + @test allocs_update3 <= 6 + + # Test operator application after update + v = rand(N) + w = similar(v) + op(w, v, u, p, t; dtgamma = 0.5) # Warm up + allocs_call = @allocations op(w, v, u, p, t; dtgamma = 0.5) + @test allocs_call <= 6 + end end