@@ -616,33 +616,47 @@ end
616616 w
617617 end
618618end
619+
619620# Out-of-place: v is action vector, u is update vector
620621function (L:: AddedOperator )(v:: AbstractVecOrMat , u, p, t; kwargs... )
621- L = update_coefficients (L , u, p, t; kwargs ... )
622+ # We don't need to update coefficients of L, as op(v , u, p, t) will do it for each op
622623 sum (op -> iszero (op) ? zero (v) : op (v, u, p, t; kwargs... ), L. ops)
623624end
624625
625626# In-place: w is destination, v is action vector, u is update vector
626- function (L:: AddedOperator )(w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t; kwargs... )
627- update_coefficients! (L, u, p, t; kwargs... )
628- for op in L. ops
629- if ! iszero (op)
630- op (w, v, u, p, t, 1.0 , 1.0 ; kwargs... )
627+ @generated function (L:: AddedOperator )(w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t; kwargs... )
628+ # We don't need to update coefficients of L, as op(w, v, u, p, t) will do it for each op
629+
630+ T = L. parameters[1 ]
631+ ops_types = L. parameters[2 ]. parameters
632+ N = length (ops_types)- 1
633+
634+ quote
635+ L. ops[1 ](w, v, u, p, t; kwargs... )
636+ Base. @nexprs $ N i-> begin
637+ op = L. ops[i+ 1 ]
638+ op (w, v, u, p, t, one ($ T), one ($ T); kwargs... )
631639 end
640+ w
632641 end
633- w
634642end
635643
636644# In-place with scaling: w = α*(L*v) + β*w
637- function (L:: AddedOperator )(w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t, α, β; kwargs... )
638- update_coefficients! (L, u, p, t; kwargs... )
639- lmul! (β, w)
640- for op in L. ops
641- if ! iszero (op)
642- op (w, v, u, p, t, α, 1.0 ; kwargs... )
645+ @generated function (L:: AddedOperator )(w:: AbstractVecOrMat , v:: AbstractVecOrMat , u, p, t, α, β; kwargs... )
646+ # We don't need to update coefficients of L, as op(w, v, u, p, t) will do it for each op
647+
648+ T = L. parameters[1 ]
649+ ops_types = L. parameters[2 ]. parameters
650+ N = length (ops_types)- 1
651+
652+ quote
653+ L. ops[1 ](w, v, u, p, t, α, β; kwargs... )
654+ Base. @nexprs $ N i-> begin
655+ op = L. ops[i+ 1 ]
656+ op (w, v, u, p, t, α, one ($ T); kwargs... )
643657 end
658+ w
644659 end
645- w
646660end
647661
648662"""
0 commit comments