Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 83 additions & 21 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,15 +887,35 @@ function cache_internals(L::ComposedOperator, v::AbstractVecOrMat)
@reset L.ops = ops
end

function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
@assert iscached(L) """cache needs to be set up for operator of type
$L. Set up cache by calling `cache_operator(L, v)`"""
@generated function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.parameters[2].parameters) # Number of operators

# Generate the mul! calls in reverse order
# vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v)
# For i in reverse(1:N):
# mul!(vecs[i], L.ops[i], vecs[i+1])

exprs = []
for i in N:-1:1
if i == N
# Last operator: mul!(L.cache[N-1], L.ops[N], v)
push!(exprs, :(mul!(L.cache[$(N - 1)], L.ops[$i], v)))
elseif i == 1
# First operator: mul!(w, L.ops[1], L.cache[1])
push!(exprs, :(mul!(w, L.ops[$i], L.cache[1])))
else
# Middle operators: mul!(L.cache[i-1], L.ops[i], L.cache[i])
push!(exprs, :(mul!(L.cache[$(i - 1)], L.ops[$i], L.cache[$i])))
end
end

vecs = (w, L.cache[1:(end - 1)]..., v)
for i in reverse(1:length(L.ops))
mul!(vecs[i], L.ops[i], vecs[i + 1])
quote
@assert iscached(L) """cache needs to be set up for operator of type
$L. Set up cache by calling `cache_operator(L, v)`"""

$(exprs...)
w
end
w
end

function LinearAlgebra.mul!(w::AbstractVecOrMat,
Expand All @@ -914,15 +934,36 @@ function LinearAlgebra.mul!(w::AbstractVecOrMat,
axpy!(β, cache, w)
end

function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
@assert iscached(L) """cache needs to be set up for operator of type
$L. Set up cache by calling `cache_operator(L, v)`."""
@generated function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.parameters[2].parameters) # Number of operators

vecs = (v, reverse(L.cache[1:(end - 1)])..., w)
for i in 1:length(L.ops)
ldiv!(vecs[i + 1], L.ops[i], vecs[i])
# Generate the ldiv! calls in forward order
# vecs conceptually is (v, reverse(L.cache[1:(N-1)])..., w)
# = (v, L.cache[N-1], L.cache[N-2], ..., L.cache[1], w)
# For i in 1:N:
# ldiv!(vecs[i+1], L.ops[i], vecs[i])

exprs = []
for i in 1:N
if i == 1
# First operator: ldiv!(L.cache[N-1], L.ops[1], v)
push!(exprs, :(ldiv!(L.cache[$(N - 1)], L.ops[$i], v)))
elseif i == N
# Last operator: ldiv!(w, L.ops[N], L.cache[1])
push!(exprs, :(ldiv!(w, L.ops[$i], L.cache[1])))
else
# Middle operators: ldiv!(L.cache[N-i], L.ops[i], L.cache[N-i+1])
push!(exprs, :(ldiv!(L.cache[$(N - i)], L.ops[$i], L.cache[$(N - i + 1)])))
end
end

quote
@assert iscached(L) """cache needs to be set up for operator of type
$L. Set up cache by calling `cache_operator(L, v)`."""

$(exprs...)
w
end
w
end

function LinearAlgebra.ldiv!(L::ComposedOperator, v::AbstractVecOrMat)
Expand All @@ -943,15 +984,36 @@ function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
end

# In-place: w is destination, v is action vector, u is update vector
function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
update_coefficients!(L, u, p, t; kwargs...)
@assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."
@generated function (L::ComposedOperator)(
w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
N = length(L.parameters[2].parameters) # Number of operators

# Generate the operator call expressions in reverse order
# vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v)
# For i in reverse(1:N):
# L.ops[i](vecs[i], vecs[i+1], u, p, t; kwargs...)

exprs = []
for i in N:-1:1
if i == N
# Last operator: L.ops[N](L.cache[N-1], v, u, p, t; kwargs...)
push!(exprs, :(L.ops[$i](L.cache[$(N - 1)], v, u, p, t; kwargs...)))
elseif i == 1
# First operator: L.ops[1](w, L.cache[1], u, p, t; kwargs...)
push!(exprs, :(L.ops[$i](w, L.cache[1], u, p, t; kwargs...)))
else
# Middle operators: L.ops[i](L.cache[i-1], L.cache[i], u, p, t; kwargs...)
push!(exprs, :(L.ops[$i](L.cache[$(i - 1)], L.cache[$i], u, p, t; kwargs...)))
end
end

vecs = (w, L.cache[1:(end - 1)]..., v)
for i in reverse(1:length(L.ops))
L.ops[i](vecs[i], vecs[i + 1], u, p, t; kwargs...)
quote
update_coefficients!(L, u, p, t; kwargs...)
@assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."

$(exprs...)
w
end
w
end

# In-place with scaling: w = α*(L*v) + β*w
Expand Down
107 changes: 105 additions & 2 deletions test/downstream/alloccheck.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using SciMLOperators, Random, SparseArrays, Test
using SciMLOperators, Random, SparseArrays, Test, LinearAlgebra
using SciMLOperators: IdentityOperator,
NullOperator,
ScaledOperator,
AddedOperator
AddedOperator,
ComposedOperator,
cache_operator

function apply_op!(H, w, v, u, p, t)
H(w, v, u, p, t)
Expand Down Expand Up @@ -64,4 +66,105 @@ test_apply_noalloc(H, w, v, u, p, t) = @test (@allocations apply_op!(H, w, v, u,
test_apply_noalloc(H_sparse, w, v, u, p, t)
test_apply_noalloc(H_dense, w, v, u, p, t)
end

# Test ComposedOperator allocations (PR #316)
# Before the fix, tuple splatting caused many allocations.
# After the fix, we should have minimal allocations (Julia 1.11 has 1, earlier versions have 0).
@testset "ComposedOperator minimal allocations" begin
N = 100

# Create operators for composition
A1 = MatrixOperator(rand(N, N))
A2 = MatrixOperator(rand(N, N))
A3 = MatrixOperator(rand(N, N))

# Create ComposedOperator
L = A1 * A2 * A3

# Set up cache
v = rand(N)
w = similar(v)
L = cache_operator(L, v)

u = rand(N)
p = nothing
t = 0.0

# Warm up
mul!(w, L, v)
L(w, v, u, p, t)

# Test mul! - should have minimal allocations
# Julia 1.11 has a known minor allocation issue (1 allocation)
# Earlier versions should have 0 allocations
allocs_mul = @allocations mul!(w, L, v)
@test allocs_mul <= 1

# Test operator call - should have minimal allocations
allocs_call = @allocations L(w, v, u, p, t)
@test allocs_call <= 1

# Test with matrices
K = 5
V = rand(N, K)
W = similar(V)
L_mat = cache_operator(A1 * A2 * A3, V)

# Warm up
mul!(W, L_mat, V)
L_mat(W, V, u, p, t)

# Test with matrices - should have minimal allocations
allocs_mul_mat = @allocations mul!(W, L_mat, V)
@test allocs_mul_mat <= 1

allocs_call_mat = @allocations L_mat(W, V, u, p, t)
@test allocs_call_mat <= 1
end

# Test accepted_kwargs allocations (PR #313)
# With Val(tuple), kwarg filtering should be compile-time with minimal allocations
@testset "accepted_kwargs with Val" begin
N = 50

# Create a MatrixOperator with accepted_kwargs using Val for compile-time filtering
J = rand(N, N)

update_func! = (M, u, p, t; dtgamma = 1.0) -> begin
M .= dtgamma .* J
nothing
end

op = MatrixOperator(
copy(J);
update_func! = update_func!,
accepted_kwargs = Val((:dtgamma,)) # Use Val for compile-time filtering
)

u = rand(N)
p = nothing
t = 0.0

# Warm up
update_coefficients!(op, u, p, t; dtgamma = 0.5)

# Test that update_coefficients! with accepted_kwargs has minimal allocations
# The Val approach significantly reduces allocations compared to plain tuples
allocs_update = @allocations update_coefficients!(op, u, p, t; dtgamma = 0.5)
@test allocs_update <= 6 # Some allocations may occur due to Julia version/kwarg handling

# Test with different dtgamma values - should have similar behavior
allocs_update2 = @allocations update_coefficients!(op, u, p, t; dtgamma = 1.0)
@test allocs_update2 <= 6

allocs_update3 = @allocations update_coefficients!(op, u, p, t; dtgamma = 2.0)
@test allocs_update3 <= 6

# Test operator application after update
v = rand(N)
w = similar(v)
op(w, v, u, p, t; dtgamma = 0.5) # Warm up
allocs_call = @allocations op(w, v, u, p, t; dtgamma = 0.5)
@test allocs_call <= 6
end
end
Loading