Skip to content

Commit 2d316cc

Browse files
Add generated functions for AddedOperator
1 parent 37194cf commit 2d316cc

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

src/basic.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -616,33 +616,47 @@ end
616616
w
617617
end
618618
end
619+
619620
# Out-of-place: v is action vector, u is update vector
620621
function (L::AddedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
621-
L = update_coefficients(L, u, p, t; kwargs...)
622+
# We don't need to update coefficients of L, as op(v, u, p, t) will do it for each op
622623
sum(op -> iszero(op) ? zero(v) : op(v, u, p, t; kwargs...), L.ops)
623624
end
624625

625626
# In-place: w is destination, v is action vector, u is update vector
626-
function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
627-
update_coefficients!(L, u, p, t; kwargs...)
628-
for op in L.ops
629-
if !iszero(op)
630-
op(w, v, u, p, t, 1.0, 1.0; kwargs...)
627+
@generated function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
628+
# We don't need to update coefficients of L, as op(w, v, u, p, t) will do it for each op
629+
630+
T = L.parameters[1]
631+
ops_types = L.parameters[2].parameters
632+
N = length(ops_types)-1
633+
634+
quote
635+
L.ops[1](w, v, u, p, t; kwargs...)
636+
Base.@nexprs $N i->begin
637+
op = L.ops[i+1]
638+
op(w, v, u, p, t, one($T), one($T); kwargs...)
631639
end
640+
w
632641
end
633-
w
634642
end
635643

636644
# In-place with scaling: w = α*(L*v) + β*w
637-
function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
638-
update_coefficients!(L, u, p, t; kwargs...)
639-
lmul!(β, w)
640-
for op in L.ops
641-
if !iszero(op)
642-
op(w, v, u, p, t, α, 1.0; kwargs...)
645+
@generated function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
646+
# We don't need to update coefficients of L, as op(w, v, u, p, t) will do it for each op
647+
648+
T = L.parameters[1]
649+
ops_types = L.parameters[2].parameters
650+
N = length(ops_types)-1
651+
652+
quote
653+
L.ops[1](w, v, u, p, t, α, β; kwargs...)
654+
Base.@nexprs $N i->begin
655+
op = L.ops[i+1]
656+
op(w, v, u, p, t, α, one($T); kwargs...)
643657
end
658+
w
644659
end
645-
w
646660
end
647661

648662
"""

0 commit comments

Comments
 (0)