Skip to content

Commit 641d72d

Browse files
Merge pull request #316 from ChrisRackauckas-Claude/fix-composed-operator-allocations
Fix allocation issues in ComposedOperator by using @generated functions
2 parents 572d354 + b6f2456 commit 641d72d

File tree

2 files changed

+188
-23
lines changed

2 files changed

+188
-23
lines changed

src/basic.jl

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -887,15 +887,35 @@ function cache_internals(L::ComposedOperator, v::AbstractVecOrMat)
887887
@reset L.ops = ops
888888
end
889889

# Apply the cached composition `L` to `v` in place: the operators are applied
# right to left (`L.ops[N]` first, `L.ops[1]` last), the result is written into
# `w`, and `w` is returned. Intermediate results go through `L.cache[1]`, ...,
# `L.cache[N - 1]`.
#
# The chain is unrolled when the method is generated, one `mul!` call per
# operator, so no tuple of buffers is built or splatted at run time.
@generated function LinearAlgebra.mul!(
        w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
    # Inside the generator `L` is the operator *type*. Its second type parameter
    # is taken to be the tuple type of `L.ops`, so this counts the operators.
    N = length(L.parameters[2].parameters)

    # Buffer chain (w, L.cache[1], ..., L.cache[N-1], v), held as expressions.
    # Indexing one chain uniformly, instead of special-casing the first and last
    # operator, keeps N == 1 correct: it collapses to `mul!(w, L.ops[1], v)`
    # rather than addressing `L.cache[0]`.
    vecs = Any[:w]
    for j in 1:(N - 1)
        push!(vecs, :(L.cache[$j]))
    end
    push!(vecs, :v)

    # Reverse order: mul!(vecs[i], L.ops[i], vecs[i + 1]) for i = N, ..., 1.
    exprs = [:(mul!($(vecs[i]), L.ops[$i], $(vecs[i + 1]))) for i in N:-1:1]

    quote
        @assert iscached(L) """cache needs to be set up for operator of type
        $L. Set up cache by calling `cache_operator(L, v)`"""

        $(exprs...)
        w
    end
end
900920

901921
function LinearAlgebra.mul!(w::AbstractVecOrMat,
@@ -914,15 +934,36 @@ function LinearAlgebra.mul!(w::AbstractVecOrMat,
914934
axpy!(β, cache, w)
915935
end
916936

# Solve with the cached composition `L` in place: `ldiv!` is applied operator by
# operator in forward order (`L.ops[1]` first, `L.ops[N]` last), the result is
# written into `w`, and `w` is returned. Intermediate results go through
# `L.cache[N - 1]` down to `L.cache[1]`.
#
# The chain is unrolled when the method is generated, one `ldiv!` call per
# operator, so no tuple of buffers is built or splatted at run time.
@generated function LinearAlgebra.ldiv!(
        w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat)
    # Inside the generator `L` is the operator *type*. Its second type parameter
    # is taken to be the tuple type of `L.ops`, so this counts the operators.
    N = length(L.parameters[2].parameters)

    # Buffer chain (v, L.cache[N-1], ..., L.cache[1], w), held as expressions.
    # Indexing one chain uniformly, instead of special-casing the first and last
    # operator, keeps N == 1 correct: it collapses to `ldiv!(w, L.ops[1], v)`
    # rather than addressing `L.cache[0]`.
    vecs = Any[:v]
    for j in (N - 1):-1:1
        push!(vecs, :(L.cache[$j]))
    end
    push!(vecs, :w)

    # Forward order: ldiv!(vecs[i + 1], L.ops[i], vecs[i]) for i = 1, ..., N.
    exprs = [:(ldiv!($(vecs[i + 1]), L.ops[$i], $(vecs[i]))) for i in 1:N]

    quote
        @assert iscached(L) """cache needs to be set up for operator of type
        $L. Set up cache by calling `cache_operator(L, v)`."""

        $(exprs...)
        w
    end
end
927968

928969
function LinearAlgebra.ldiv!(L::ComposedOperator, v::AbstractVecOrMat)
@@ -943,15 +984,36 @@ function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
943984
end
944985

945986
# In-place: w is destination, v is action vector, u is update vector
# The generated body first calls `update_coefficients!(L, u, p, t; kwargs...)`,
# then applies the operators right to left (`L.ops[N]` first, `L.ops[1]` last),
# each through its own in-place call `op(dest, src, u, p, t; kwargs...)`.
# Intermediate results go through `L.cache[1]`, ..., `L.cache[N - 1]`; the final
# result is written into `w`, which is returned.
#
# The chain is unrolled when the method is generated, one call per operator, so
# no tuple of buffers is built or splatted at run time.
@generated function (L::ComposedOperator)(
        w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
    # Inside the generator `L` is the operator *type*. Its second type parameter
    # is taken to be the tuple type of `L.ops`, so this counts the operators.
    N = length(L.parameters[2].parameters)

    # Buffer chain (w, L.cache[1], ..., L.cache[N-1], v), held as expressions.
    # Indexing one chain uniformly, instead of special-casing the first and last
    # operator, keeps N == 1 correct: it collapses to
    # `L.ops[1](w, v, u, p, t; kwargs...)` rather than addressing `L.cache[0]`.
    vecs = Any[:w]
    for j in 1:(N - 1)
        push!(vecs, :(L.cache[$j]))
    end
    push!(vecs, :v)

    # Reverse order: L.ops[i](vecs[i], vecs[i + 1], ...) for i = N, ..., 1.
    exprs = [:(L.ops[$i]($(vecs[i]), $(vecs[i + 1]), u, p, t; kwargs...))
             for i in N:-1:1]

    quote
        update_coefficients!(L, u, p, t; kwargs...)
        @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."

        $(exprs...)
        w
    end
end
9561018

9571019
# In-place with scaling: w = α*(L*v) + β*w

test/downstream/alloccheck.jl

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
using SciMLOperators, Random, SparseArrays, Test
1+
using SciMLOperators, Random, SparseArrays, Test, LinearAlgebra
22
using SciMLOperators: IdentityOperator,
33
NullOperator,
44
ScaledOperator,
5-
AddedOperator
5+
AddedOperator,
6+
ComposedOperator,
7+
cache_operator
68

79
function apply_op!(H, w, v, u, p, t)
810
H(w, v, u, p, t)
@@ -64,4 +66,105 @@ test_apply_noalloc(H, w, v, u, p, t) = @test (@allocations apply_op!(H, w, v, u,
6466
test_apply_noalloc(H_sparse, w, v, u, p, t)
6567
test_apply_noalloc(H_dense, w, v, u, p, t)
6668
end
69+
# Allocation regression test for ComposedOperator (PR #316).
# The tuple-splatting implementation allocated on every application. The
# unrolled implementation is expected to allocate at most once per call
# (one allocation is tolerated on Julia 1.11, zero on earlier versions).
@testset "ComposedOperator minimal allocations" begin
    dim = 100

    # Three dense factors to compose.
    op1 = MatrixOperator(rand(dim, dim))
    op2 = MatrixOperator(rand(dim, dim))
    op3 = MatrixOperator(rand(dim, dim))

    # Vector case: build the composition and attach its cache for `src`.
    src = rand(dim)
    dest = similar(src)
    composed = cache_operator(op1 * op2 * op3, src)

    u = rand(dim)
    p = nothing
    t = 0.0

    # Run both entry points once so compilation is not counted below.
    mul!(dest, composed, src)
    composed(dest, src, u, p, t)

    # mul! on vectors
    n_mul = @allocations mul!(dest, composed, src)
    @test n_mul <= 1

    # in-place operator call on vectors
    n_call = @allocations composed(dest, src, u, p, t)
    @test n_call <= 1

    # Matrix case: same composition, cached for a dim-by-ncols right-hand side.
    ncols = 5
    Src = rand(dim, ncols)
    Dest = similar(Src)
    composed_mat = cache_operator(op1 * op2 * op3, Src)

    # Run both entry points once so compilation is not counted below.
    mul!(Dest, composed_mat, Src)
    composed_mat(Dest, Src, u, p, t)

    # mul! on matrices
    n_mul_mat = @allocations mul!(Dest, composed_mat, Src)
    @test n_mul_mat <= 1

    # in-place operator call on matrices
    n_call_mat = @allocations composed_mat(Dest, Src, u, p, t)
    @test n_call_mat <= 1
end
124+
# Allocation check for accepted_kwargs (PR #313).
# Passing the accepted keyword names as Val(tuple) is meant to let the keyword
# filtering happen at compile time, keeping allocations small.
@testset "accepted_kwargs with Val" begin
    dim = 50
    base = rand(dim, dim)

    # In-place update: overwrite the operator's matrix with dtgamma * base.
    function rescale!(M, u, p, t; dtgamma = 1.0)
        M .= dtgamma .* base
        return nothing
    end

    # accepted_kwargs is given as a Val so the filter is known at compile time.
    op = MatrixOperator(
        copy(base);
        update_func! = rescale!,
        accepted_kwargs = Val((:dtgamma,))
    )

    u = rand(dim)
    p = nothing
    t = 0.0

    # Run once so compilation is not counted below.
    update_coefficients!(op, u, p, t; dtgamma = 0.5)

    # The bound is loose on purpose: a few allocations may remain depending on
    # the Julia version and its keyword-argument handling.
    n_update_half = @allocations update_coefficients!(op, u, p, t; dtgamma = 0.5)
    @test n_update_half <= 6

    # Other dtgamma values must stay within the same bound.
    n_update_one = @allocations update_coefficients!(op, u, p, t; dtgamma = 1.0)
    @test n_update_one <= 6

    n_update_two = @allocations update_coefficients!(op, u, p, t; dtgamma = 2.0)
    @test n_update_two <= 6

    # Applying the operator with the keyword, after one warm-up call.
    src = rand(dim)
    dest = similar(src)
    op(dest, src, u, p, t; dtgamma = 0.5)
    n_apply = @allocations op(dest, src, u, p, t; dtgamma = 0.5)
    @test n_apply <= 6
end
67170
end

0 commit comments

Comments
 (0)