@@ -12,7 +12,6 @@ function Base.one(A::AbstractSciMLOperator)
1212 IdentityOperator {N} ()
1313end
1414
15- # TODO - sparse diagonal
1615Base. convert (:: Type{AbstractMatrix} , :: IdentityOperator{N} ) where {N} = Diagonal (ones (Bool, N))
1716
1817# traits
@@ -187,8 +186,10 @@ function Base.adjoint(α::ScalarOperator) # TODO - test
187186 ScalarOperator (val; update_func= update_func)
188187end
189188Base. transpose (α:: ScalarOperator ) = α
189+ Base. one (:: Type{AbstractSciMLOperator} ) = ScalarOperator (true )
190+ Base. zero (:: Type{AbstractSciMLOperator} ) = ScalarOperator (false )
190191
191- getops (α:: ScalarOperator ) = (α. val)
192+ getops (α:: ScalarOperator ) = (α. val, )
192193islinear (L:: ScalarOperator ) = true
193194issquare (L:: ScalarOperator ) = true
194195isconstant (α:: ScalarOperator ) = α. update_func == DEFAULT_UPDATE_FUNC
@@ -249,9 +250,11 @@ struct ScaledOperator{T,
249250 L:: LType
250251 cache:: T
251252
252- function ScaledOperator (λ:: ScalarOperator , L:: AbstractSciMLOperator )
253- T = promote_type (eltype .((λ, L))... )
254- cache = zero (T)
253+ function ScaledOperator (λ:: ScalarOperator{Tλ} ,
254+ L:: AbstractSciMLOperator{TL} ,
255+ cache = zero (promote_type (Tλ,TL))
256+ ) where {Tλ,TL}
257+ T = promote_type (Tλ, TL)
255258 new {T,typeof(λ),typeof(L)} (λ, L, cache)
256259 end
257260end
@@ -297,7 +300,7 @@ for op in (
297300end
298301LinearAlgebra. opnorm (L:: ScaledOperator , p:: Real = 2 ) = abs (L. λ) * opnorm (L. L, p)
299302
300- getops (L:: ScaledOperator ) = (L. λ, L. L)
303+ getops (L:: ScaledOperator ) = (L. λ, L. L, )
301304islinear (L:: ScaledOperator ) = all (islinear, L. ops)
302305isconstant (L:: ScaledOperator ) = isconstant (L. L) & isconstant (L. λ)
303306Base. iszero (L:: ScaledOperator ) = iszero (L. L) | iszero (L. λ)
@@ -306,6 +309,12 @@ has_mul!(L::ScaledOperator) = has_mul!(L.L)
306309has_ldiv (L:: ScaledOperator ) = has_ldiv (L. L) & ! iszero (L. λ)
307310has_ldiv! (L:: ScaledOperator ) = has_ldiv! (L. L) & ! iszero (L. λ)
308311
312+ function cache_internals (L:: ScaledOperator , u:: AbstractVector )
313+ @set! L. L = cache_operator (L. L, u)
314+ @set! L. λ = cache_operator (L. λ, u)
315+ L
316+ end
317+
309318# getindex
310319Base. getindex (L:: ScaledOperator , i:: Int ) = L. coeff * L. op[i]
311320Base. getindex (L:: ScaledOperator , I:: Vararg{Int, N} ) where {N} = L. λ * L. L[I... ]
@@ -348,7 +357,7 @@ function LinearAlgebra.ldiv!(L::ScaledOperator, u::AbstractVector)
348357end
349358
350359"""
351- Lazy operator addition (A + B)
360+ Lazy operator addition
352361
353362 (A1 + A2 + A3...)u = A1*u + A2*u + A3*u ....
354363"""
@@ -357,17 +366,22 @@ struct AddedOperator{T,
357366 } <: AbstractSciMLOperator{T}
358367 ops:: O
359368
360- function AddedOperator (ops... )
361- sz = size (first (ops))
362- for op in ops[2 : end ]
363- @assert size (op) == sz " Size mismatich in operators $ops "
364- end
365-
369+ function AddedOperator (ops)
366370 T = promote_type (eltype .(ops)... )
367371 new {T,typeof(ops)} (ops)
368372 end
369373end
370374
375+ function AddedOperator (ops:: AbstractSciMLOperator... )
376+ sz = size (first (ops))
377+ for op in ops[2 : end ]
378+ @assert size (op) == sz " Size mismatich in operators $ops "
379+ end
380+ AddedOperator (ops)
381+ end
382+
383+ AddedOperator (L:: AbstractSciMLOperator ) = L
384+
371385# constructors
372386Base.:+ (ops:: AbstractSciMLOperator... ) = AddedOperator (ops... )
373387Base.:- (A:: AbstractSciMLOperator , B:: AbstractSciMLOperator ) = AddedOperator (A, - B)
@@ -423,6 +437,13 @@ getops(L::AddedOperator) = L.ops
423437Base. iszero (L:: AddedOperator ) = all (iszero, getops (L))
424438has_adjoint (L:: AddedOperator ) = all (has_adjoint, L. ops)
425439
440+ function cache_internals (L:: AddedOperator , u:: AbstractVector )
441+ for i= 1 : length (L. ops)
442+ @set! L. ops[i] = cache_operator (L. ops[i], u)
443+ end
444+ L
445+ end
446+
426447getindex (L:: AddedOperator , i:: Int ) = sum (op -> op[i], L. ops)
427448getindex (L:: AddedOperator , I:: Vararg{Int, N} ) where {N} = sum (op -> op[I... ], L. ops)
428449
@@ -532,30 +553,35 @@ end
532553Base.:* (L:: ComposedOperator , u:: AbstractVector ) = foldl ((acc, op) -> op * acc, reverse (L. ops); init= u)
533554Base.:\ (L:: ComposedOperator , u:: AbstractVector ) = foldl ((acc, op) -> op \ acc, L. ops; init= u)
534555
535- function cache_operator (L:: ComposedOperator , u:: AbstractVector )
536- # for 3 arg mul!
537- # Tuple of N-1 cache vectors. cache[N-1] = op[N] * u and so on
538- vec = u
539- c3 = ()
556+ function cache_self (L:: ComposedOperator , u:: AbstractVector )
557+ vec = similar (u)
558+ cache = (vec,)
540559 for i in reverse (2 : length (L. ops))
541- vec = L. ops[i] * vec
542- c3 = (c3 ... , vec )
560+ vec = L. ops[i] * vec
561+ cache = (vec, cache ... )
543562 end
544563
545- # for 5 arg mul!
546- c5 = similar (u)
564+ @set! L. cache = cache
565+ L
566+ end
547567
548- cache = (;c3= c3, c5= c5)
568+ function cache_internals (L:: ComposedOperator , u:: AbstractVector )
569+ if ! (L. isset)
570+ L = cache_self (L, u)
571+ end
572+
573+ vecs = L. cache
574+ for i in reverse (1 : length (L. ops))
575+ @set! L. ops[i] = cache_operator (L. ops[i], vecs[i])
576+ end
549577
550- @set! L. cache = cache
551578 L
552579end
553580
554581function LinearAlgebra. mul! (v:: AbstractVector , L:: ComposedOperator , u:: AbstractVector )
555582 @assert L. isset " cache needs to be set up to use LinearAlgebra.mul!"
556583
557- cache = L. cache. c3
558- vecs = (v, cache... , u)
584+ vecs = (v, L. cache[1 : end - 1 ]. .. , u)
559585 for i in reverse (1 : length (L. ops))
560586 mul! (vecs[i], L. ops[i], vecs[i+ 1 ])
561587 end
565591function LinearAlgebra. mul! (v:: AbstractVector , L:: ComposedOperator , u:: AbstractVector , α, β)
566592 @assert L. isset " cache needs to be set up to use LinearAlgebra.mul!"
567593
568- cache = L. cache. c5
594+ cache = L. cache[ end ]
569595 copy! (cache, v)
570596
571597 mul! (v, L, u)
576602function LinearAlgebra. ldiv! (v:: AbstractVector , L:: ComposedOperator , u:: AbstractVector )
577603 @assert L. isset " cache needs to be set up to use 3 arg LinearAlgebra.ldiv!"
578604
579- cache = L. cache. c3
580- vecs = (u, reverse (cache)... , v)
605+ vecs = (u, reverse (L. cache[1 : end - 1 ])... , v)
581606 for i in 1 : length (L. ops)
582607 ldiv! (vecs[i+ 1 ], L. ops[i], vecs[i])
583608 end
@@ -645,6 +670,11 @@ for (op, LType, VType) in (
645670 has_ldiv!,
646671 )
647672
673+ @eval function cache_internals (L:: $LType , u:: AbstractVector )
674+ @set! L. L = cache_operator (L. L, _reshape (u, size (L,1 )))
675+ L
676+ end
677+
648678 # oeprator application
649679 @eval Base.:* (u:: $VType , L:: $LType ) = $ op (L. L * u. parent)
650680 @eval Base.:/ (u:: $VType , L:: $LType ) = $ op (L. L \ u. parent)
@@ -724,12 +754,17 @@ has_ldiv!(L::InvertedOperator) = has_mul!(L.L)
724754Base.:* (L:: InvertedOperator , u:: AbstractVector ) = L. L \ u
725755Base.:\ (L:: InvertedOperator , u:: AbstractVector ) = L. L * u
726756
727- function cache_operator (L:: InvertedOperator , u:: AbstractVector )
757+ function cache_self (L:: InvertedOperator , u:: AbstractVector )
728758 cache = similar (u)
729759 @set! L. cache = cache
730760 L
731761end
732762
763+ function cache_internals (L:: InvertedOperator , u:: AbstractVector )
764+ @set! L. L = cache_operator (L. L, u)
765+ L
766+ end
767+
733768function LinearAlgebra. mul! (v:: AbstractVector , L:: InvertedOperator , u:: AbstractVector )
734769 ldiv! (v, L. L, u)
735770end
0 commit comments