diff --git a/src/basic.jl b/src/basic.jl index 0ccc7d35..91899f08 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -77,16 +77,14 @@ function (ii::IdentityOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) end # In-place: w is destination, v is action vector, u is update vector -function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) +@inline function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) @assert size(v, 1) == ii.len - update_coefficients!(ii, u, p, t; kwargs...) copy!(w, v) end # In-place with scaling: w = α*(ii*v) + β*w -function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) +@inline function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) @assert size(v, 1) == ii.len - update_coefficients!(ii, u, p, t; kwargs...) mul!(w, I, v, α, β) end @@ -388,29 +386,17 @@ function (L::ScaledOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) end # In-place: w is destination, v is action vector, u is update vector -function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) - update_coefficients!(L, u, p, t; kwargs...) - if iszero(L.λ) - lmul!(false, w) - return w - else - a = convert(Number, L.λ) - mul!(w, L.L, v, a, false) - return w - end +@inline function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L.λ, u, p, t; kwargs...) + a = convert(Number, L.λ) + return L.L(w, v, u, p, t, a, false; kwargs...) end # In-place with scaling: w = α*(L*v) + β*w -function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) - update_coefficients!(L, u, p, t; kwargs...) - if iszero(L.λ) - lmul!(β, w) - return w - else - a = convert(Number, L.λ * α) - mul!(w, L.L, v, a, β) - return w - end +@inline function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L.λ, u, p, t; kwargs...) + a = convert(Number, L.λ * α) + return L.L(w, v, u, p, t, a, β; kwargs...) end