Skip to content

Commit e995a85

Browse files
Fix allocation issues in ComposedOperator by using @generated functions
Convert `mul!`, `ldiv!`, and the operator-call functions for `ComposedOperator` to `@generated` functions instead of runtime tuple splatting. This eliminates the GC-triggering allocations that were causing performance regressions.

The issue was that these functions built tuples dynamically on every call:
- `vecs = (w, L.cache[1:(end - 1)]..., v)` in `mul!` and the operator calls
- `vecs = (v, reverse(L.cache[1:(end - 1)])..., w)` in `ldiv!`

These splatting operations allocated memory on each call, triggering frequent GC and degrading performance. The `@generated` versions emit specialized code at compile time that references the cache elements directly, without intermediate allocations, following the pattern established for `AddedOperator`.

Fixes #315

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 572d354 commit e995a85

File tree

1 file changed

+82
-21
lines changed

1 file changed

+82
-21
lines changed

src/basic.jl

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -887,15 +887,35 @@ function cache_internals(L::ComposedOperator, v::AbstractVecOrMat)
887887
@reset L.ops = ops
888888
end
889889

890-
# Three-argument in-place product `w = L * v` for a composition `L = ops[1] ∘ ... ∘ ops[N]`.
#
# The operators are applied right to left, threading intermediate results through
# `L.cache[1], ..., L.cache[N-1]`. The call sequence is unrolled at compile time so
# that no tuple of work vectors has to be built (and allocated) on every call.
#
# Conceptually, with `vecs = (w, L.cache[1], ..., L.cache[N-1], v)`:
#
#     for i in reverse(1:N)
#         mul!(vecs[i], L.ops[i], vecs[i + 1])
#     end
#
# Building `vecs` as a list of expressions (instead of special-casing the first and
# last operator) keeps the single-operator case correct: for `N == 1` the only call
# is `mul!(w, L.ops[1], v)` and the cache is never indexed.
#
# Returns `w`. Raises an `AssertionError` if the cache has not been set up.
@generated function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator,
        v::AbstractVecOrMat)
    # Inside the generator `L` is the concrete operator type; the number of composed
    # operators is the arity of the tuple type stored in its `ops` field.
    N = fieldcount(fieldtype(L, :ops))

    # Expressions for (w, L.cache[1], ..., L.cache[N-1], v).
    vecs = Any[:w, (:(L.cache[$j]) for j in 1:(N - 1))..., :v]

    # Right-to-left application: ops[N] first, ops[1] last.
    calls = [:(mul!($(vecs[i]), L.ops[$i], $(vecs[i + 1]))) for i in N:-1:1]

    return quote
        @assert iscached(L) """cache needs to be set up for operator of type
        $L. Set up cache by calling `cache_operator(L, v)`"""

        $(calls...)
        return w
    end
end
900920

901921
function LinearAlgebra.mul!(w::AbstractVecOrMat,
@@ -914,15 +934,36 @@ function LinearAlgebra.mul!(w::AbstractVecOrMat,
914934
axpy!(β, cache, w)
915935
end
916936

917-
# Three-argument in-place solve `w = L \ v` for a composition `L = ops[1] ∘ ... ∘ ops[N]`.
#
# The inverse of a composition applies the individual inverses left to right, so the
# operators are visited in forward order while the cache slots are used in reverse.
# The call sequence is unrolled at compile time so that no tuple of work vectors has
# to be built (and allocated) on every call.
#
# Conceptually, with `vecs = (v, L.cache[N-1], ..., L.cache[1], w)`:
#
#     for i in 1:N
#         ldiv!(vecs[i + 1], L.ops[i], vecs[i])
#     end
#
# Building `vecs` as a list of expressions (instead of special-casing the first and
# last operator) keeps the single-operator case correct: for `N == 1` the only call
# is `ldiv!(w, L.ops[1], v)` and the cache is never indexed.
#
# Returns `w`. Raises an `AssertionError` if the cache has not been set up.
@generated function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator,
        v::AbstractVecOrMat)
    # Inside the generator `L` is the concrete operator type; the number of composed
    # operators is the arity of the tuple type stored in its `ops` field.
    N = fieldcount(fieldtype(L, :ops))

    # Expressions for (v, L.cache[N-1], ..., L.cache[1], w).
    vecs = Any[:v, (:(L.cache[$j]) for j in (N - 1):-1:1)..., :w]

    # Left-to-right application: ops[1] first, ops[N] last.
    calls = [:(ldiv!($(vecs[i + 1]), L.ops[$i], $(vecs[i]))) for i in 1:N]

    return quote
        @assert iscached(L) """cache needs to be set up for operator of type
        $L. Set up cache by calling `cache_operator(L, v)`."""

        $(calls...)
        return w
    end
end
927968

928969
function LinearAlgebra.ldiv!(L::ComposedOperator, v::AbstractVecOrMat)
@@ -943,15 +984,35 @@ function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
943984
end
944985

945986
# In-place: w is destination, v is action vector, u is update vector
946-
# In-place operator call: refresh the coefficients of `L` from `(u, p, t)` and then
# write the action of the composition on `v` into `w`.
#
# The component operators are called right to left, threading intermediate results
# through `L.cache[1], ..., L.cache[N-1]`. The call sequence is unrolled at compile
# time so that no tuple of work vectors has to be built (and allocated) on every call.
#
# Conceptually, with `vecs = (w, L.cache[1], ..., L.cache[N-1], v)`:
#
#     for i in reverse(1:N)
#         L.ops[i](vecs[i], vecs[i + 1], u, p, t; kwargs...)
#     end
#
# Building `vecs` as a list of expressions (instead of special-casing the first and
# last operator) keeps the single-operator case correct: for `N == 1` the only call
# is `L.ops[1](w, v, u, p, t; kwargs...)` and the cache is never indexed.
#
# Returns `w`. Raises an `AssertionError` if the cache has not been set up.
@generated function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat,
        u, p, t; kwargs...)
    # Inside the generator `L` is the concrete operator type; the number of composed
    # operators is the arity of the tuple type stored in its `ops` field.
    N = fieldcount(fieldtype(L, :ops))

    # Expressions for (w, L.cache[1], ..., L.cache[N-1], v).
    vecs = Any[:w, (:(L.cache[$j]) for j in 1:(N - 1))..., :v]

    # Right-to-left application: ops[N] first, ops[1] last.
    calls = [:(L.ops[$i]($(vecs[i]), $(vecs[i + 1]), u, p, t; kwargs...))
             for i in N:-1:1]

    return quote
        update_coefficients!(L, u, p, t; kwargs...)
        @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."

        $(calls...)
        return w
    end
end
9561017

9571018
# In-place with scaling: w = α*(L*v) + β*w

0 commit comments

Comments
 (0)