@@ -887,15 +887,35 @@ function cache_internals(L::ComposedOperator, v::AbstractVecOrMat)
887887 @reset L. ops = ops
888888end
889889
890- function LinearAlgebra. mul! (w:: AbstractVecOrMat , L:: ComposedOperator , v:: AbstractVecOrMat )
891- @assert iscached (L) """ cache needs to be set up for operator of type
892- $L . Set up cache by calling `cache_operator(L, v)`"""
890+ @generated function LinearAlgebra. mul! (w:: AbstractVecOrMat , L:: ComposedOperator , v:: AbstractVecOrMat )
891+ N = length (L. parameters[2 ]. parameters) # Number of operators
892+
893+ # Generate the mul! calls in reverse order
894+ # vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v)
895+ # For i in reverse(1:N):
896+ # mul!(vecs[i], L.ops[i], vecs[i+1])
897+
898+ exprs = []
899+ for i in N: - 1 : 1
900+ if i == N
901+ # Last operator: mul!(L.cache[N-1], L.ops[N], v)
902+ push! (exprs, :(mul! (L. cache[$ (N - 1 )], L. ops[$ i], v)))
903+ elseif i == 1
904+ # First operator: mul!(w, L.ops[1], L.cache[1])
905+ push! (exprs, :(mul! (w, L. ops[$ i], L. cache[1 ])))
906+ else
907+ # Middle operators: mul!(L.cache[i-1], L.ops[i], L.cache[i])
908+ push! (exprs, :(mul! (L. cache[$ (i - 1 )], L. ops[$ i], L. cache[$ i])))
909+ end
910+ end
893911
894- vecs = (w, L. cache[1 : (end - 1 )]. .. , v)
895- for i in reverse (1 : length (L. ops))
896- mul! (vecs[i], L. ops[i], vecs[i + 1 ])
912+ quote
913+ @assert iscached (L) """ cache needs to be set up for operator of type
914+ $L . Set up cache by calling `cache_operator(L, v)`"""
915+
916+ $ (exprs... )
917+ w
897918 end
898- w
899919end
900920
901921function LinearAlgebra. mul! (w:: AbstractVecOrMat ,
@@ -914,15 +934,36 @@ function LinearAlgebra.mul!(w::AbstractVecOrMat,
914934 axpy! (β, cache, w)
915935end
916936
917- function LinearAlgebra. ldiv! (w:: AbstractVecOrMat , L:: ComposedOperator , v:: AbstractVecOrMat )
918- @assert iscached (L) """ cache needs to be set up for operator of type
919- $L . Set up cache by calling `cache_operator(L, v)`."""
937+ @generated function LinearAlgebra. ldiv! (w:: AbstractVecOrMat , L:: ComposedOperator , v:: AbstractVecOrMat )
938+ N = length (L. parameters[2 ]. parameters) # Number of operators
920939
921- vecs = (v, reverse (L. cache[1 : (end - 1 )])... , w)
922- for i in 1 : length (L. ops)
923- ldiv! (vecs[i + 1 ], L. ops[i], vecs[i])
940+ # Generate the ldiv! calls in forward order
941+ # vecs conceptually is (v, reverse(L.cache[1:(N-1)])..., w)
942+ # = (v, L.cache[N-1], L.cache[N-2], ..., L.cache[1], w)
943+ # For i in 1:N:
944+ # ldiv!(vecs[i+1], L.ops[i], vecs[i])
945+
946+ exprs = []
947+ for i in 1 : N
948+ if i == 1
949+ # First operator: ldiv!(L.cache[N-1], L.ops[1], v)
950+ push! (exprs, :(ldiv! (L. cache[$ (N - 1 )], L. ops[$ i], v)))
951+ elseif i == N
952+ # Last operator: ldiv!(w, L.ops[N], L.cache[1])
953+ push! (exprs, :(ldiv! (w, L. ops[$ i], L. cache[1 ])))
954+ else
955+ # Middle operators: ldiv!(L.cache[N-i], L.ops[i], L.cache[N-i+1])
956+ push! (exprs, :(ldiv! (L. cache[$ (N - i)], L. ops[$ i], L. cache[$ (N - i + 1 )])))
957+ end
958+ end
959+
960+ quote
961+ @assert iscached (L) """ cache needs to be set up for operator of type
962+ $L . Set up cache by calling `cache_operator(L, v)`."""
963+
964+ $ (exprs... )
965+ w
924966 end
925- w
926967end
927968
928969function LinearAlgebra. ldiv! (L:: ComposedOperator , v:: AbstractVecOrMat )
@@ -943,15 +984,36 @@ function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
943984end
944985
945986# In-place: w is destination, v is action vector, u is update vector
946- function (L:: ComposedOperator )(w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t; kwargs... )
947- update_coefficients! (L, u, p, t; kwargs... )
948- @assert iscached (L) " Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."
987+ @generated function (L:: ComposedOperator )(
988+ w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t; kwargs... )
989+ N = length (L. parameters[2 ]. parameters) # Number of operators
990+
991+ # Generate the operator call expressions in reverse order
992+ # vecs conceptually is (w, L.cache[1], L.cache[2], ..., L.cache[N-1], v)
993+ # For i in reverse(1:N):
994+ # L.ops[i](vecs[i], vecs[i+1], u, p, t; kwargs...)
995+
996+ exprs = []
997+ for i in N: - 1 : 1
998+ if i == N
999+ # Last operator: L.ops[N](L.cache[N-1], v, u, p, t; kwargs...)
1000+ push! (exprs, :(L. ops[$ i](L. cache[$ (N - 1 )], v, u, p, t; kwargs... )))
1001+ elseif i == 1
1002+ # First operator: L.ops[1](w, L.cache[1], u, p, t; kwargs...)
1003+ push! (exprs, :(L. ops[$ i](w, L. cache[1 ], u, p, t; kwargs... )))
1004+ else
1005+ # Middle operators: L.ops[i](L.cache[i-1], L.cache[i], u, p, t; kwargs...)
1006+ push! (exprs, :(L. ops[$ i](L. cache[$ (i - 1 )], L. cache[$ i], u, p, t; kwargs... )))
1007+ end
1008+ end
9491009
950- vecs = (w, L. cache[1 : (end - 1 )]. .. , v)
951- for i in reverse (1 : length (L. ops))
952- L. ops[i](vecs[i], vecs[i + 1 ], u, p, t; kwargs... )
1010+ quote
1011+ update_coefficients! (L, u, p, t; kwargs... )
1012+ @assert iscached (L) " Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."
1013+
1014+ $ (exprs... )
1015+ w
9531016 end
954- w
9551017end
9561018
9571019# In-place with scaling: w = α*(L*v) + β*w
0 commit comments