Skip to content

Commit 1aab2cc

Browse files
Other improvements on AddedOperator
1 parent 1383b7b commit 1aab2cc

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

src/basic.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,13 @@ end
288288
Base.:*(L::ScaledOperator, u::AbstractVecOrMat) = L.λ * (L.L * u)
289289
Base.:\(L::ScaledOperator, u::AbstractVecOrMat) = L.λ \ (L.L \ u)
290290

291-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
291+
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
292292
iszero(L.λ) && return lmul!(false, v)
293293
a = convert(Number, L.λ)
294294
mul!(v, L.L, u, a, false)
295295
end
296296

297-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
297+
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat,
298298
L::ScaledOperator,
299299
u::AbstractVecOrMat,
300300
α,
@@ -386,6 +386,23 @@ for op in (:+, :-)
386386
end
387387
end
388388

389+
for T in SCALINGNUMBERTYPES[2:end]
390+
@eval function Base.:*::$T, L::AddedOperator)
391+
ops = map(op -> λ * op, L.ops)
392+
AddedOperator(ops)
393+
end
394+
395+
@eval function Base.:*(L::AddedOperator, λ::$T)
396+
ops = map(op -> λ * op, L.ops)
397+
AddedOperator(ops)
398+
end
399+
400+
@eval function Base.:/(L::AddedOperator, λ::$T)
401+
ops = map(op -> op / λ, L.ops)
402+
AddedOperator(ops)
403+
end
404+
end
405+
389406
function Base.convert(::Type{AbstractMatrix}, L::AddedOperator)
390407
sum(op -> convert(AbstractMatrix, op), L.ops)
391408
end
@@ -422,6 +439,14 @@ function update_coefficients(L::AddedOperator, u, p, t)
422439
@reset L.ops = ops
423440
end
424441

442+
function update_coefficients!(L::AddedOperator, u, p, t)
443+
for op in L.ops
444+
update_coefficients!(op, u, p, t)
445+
end
446+
447+
L
448+
end
449+
425450
getops(L::AddedOperator) = L.ops
426451
islinear(L::AddedOperator) = all(islinear, getops(L))
427452
Base.iszero(L::AddedOperator) = all(iszero, getops(L))

0 commit comments

Comments
 (0)