@@ -82,12 +82,12 @@ for op in (
8282 end
8383end
8484
85- function Base.:\ (:: IdentityOperator , A:: AbstractSciMLOperator )
85+ function Base.:\ (ii :: IdentityOperator , A:: AbstractSciMLOperator )
8686 @assert size (A, 1 ) == ii. len
8787 A
8888end
8989
90- function Base.:/ (A:: AbstractSciMLOperator , :: IdentityOperator )
90+ function Base.:/ (A:: AbstractSciMLOperator , ii :: IdentityOperator )
9191 @assert size (A, 2 ) == ii. len
9292 A
9393end
@@ -330,8 +330,9 @@ AddedOperator(L::AbstractSciMLOperator) = L
330330# constructors
331331Base.:+ (A:: AbstractSciMLOperator , B:: AbstractMatrix ) = A + MatrixOperator (B)
332332Base.:+ (A:: AbstractMatrix , B:: AbstractSciMLOperator ) = MatrixOperator (A) + B
333- Base.:+ (ops:: AbstractSciMLOperator... ) = AddedOperator (ops... )
334333
334+ Base.:+ (ops:: AbstractSciMLOperator... ) = reduce (+ , ops)
335+ Base.:+ (A:: AbstractSciMLOperator , B:: AbstractSciMLOperator ) = AddedOperator (A, B)
335336Base.:+ (A:: AbstractSciMLOperator , B:: AddedOperator ) = AddedOperator (A, B. ops... )
336337Base.:+ (A:: AddedOperator , B:: AbstractSciMLOperator ) = AddedOperator (A. ops... , B)
337338Base.:+ (A:: AddedOperator , B:: AddedOperator ) = AddedOperator (A. ops... , B. ops... )
@@ -471,16 +472,15 @@ function ComposedOperator(ops::AbstractSciMLOperator...; cache = nothing)
471472end
472473
473474# constructors
474- Base.:∘ (ops:: AbstractSciMLOperator... ) = ComposedOperator (ops... )
475- Base.:∘ (A:: ComposedOperator , B:: ComposedOperator ) = ComposedOperator (A. ops... , B. ops... )
476- Base.:∘ (A:: AbstractSciMLOperator , B:: ComposedOperator ) = ComposedOperator (A, B. ops... )
477- Base.:∘ (A:: ComposedOperator , B:: AbstractSciMLOperator ) = ComposedOperator (A. ops... , B)
478-
479- Base.:* (ops:: AbstractSciMLOperator... ) = ComposedOperator (ops... )
480- Base.:* (A:: AbstractSciMLOperator , B:: AbstractSciMLOperator ) = ∘ (A, B)
481- Base.:* (A:: ComposedOperator , B:: AbstractSciMLOperator ) = ∘ (A. ops[1 : end - 1 ]. .. , A. ops[end ] * B)
482- Base.:* (A:: AbstractSciMLOperator , B:: ComposedOperator ) = ∘ (A * B. ops[1 ], B. ops[2 : end ]. .. )
483- Base.:* (A:: ComposedOperator , B:: ComposedOperator ) = ComposedOperator (A. ops... , B. ops... )
475+ for op in (
476+ :* , :∘ ,
477+ )
478+ @eval Base.$ op (ops:: AbstractSciMLOperator... ) = reduce ($ op, ops)
479+ @eval Base.$ op (A:: AbstractSciMLOperator , B:: AbstractSciMLOperator ) = ComposedOperator (A, B)
480+ @eval Base.$ op (A:: ComposedOperator , B:: AbstractSciMLOperator ) = ComposedOperator (A. ops... , B)
481+ @eval Base.$ op (A:: AbstractSciMLOperator , B:: ComposedOperator ) = ComposedOperator (A, B. ops... )
482+ @eval Base.$ op (A:: ComposedOperator , B:: ComposedOperator ) = ComposedOperator (A. ops... , B. ops... )
483+ end
484484
485485for op in (
486486 :* , :∘ ,
@@ -606,11 +606,20 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
606606 K = size (u, 2 )
607607 cache = (zero (u),)
608608 for i in reverse (2 : length (L. ops))
609+ op = L. ops[i]
609610
610- M = size (L. ops[i], 1 )
611- T = promote_type (eltype .((L. ops[i], cache[1 ]))... )
611+ M = size (op, 1 )
612612 sz = u isa AbstractMatrix ? (M, K) : (M,)
613613
614+ T = if op isa FunctionOperator #
615+ # FunctionOperator isn't guaranteed to play by the rules of
616+ # `promote_type`. For example, an rFFT is a complex operation
617+ # that accepts and complex vector and returns a real one.
618+ op. traits. eltypes[2 ]
619+ else
620+ promote_type (eltype .((op, cache[1 ]))... )
621+ end
622+
614623 cache = (similar (u, T, sz), cache... )
615624 end
616625
@@ -623,12 +632,12 @@ function cache_internals(L::ComposedOperator, u::AbstractVecOrMat)
623632 L = cache_self (L, u)
624633 end
625634
626- vecs = L . cache
635+ ops = ()
627636 for i in reverse (1 : length (L. ops))
628- @set! L . ops[i] = cache_operator (L. ops[i], vecs [i])
637+ ops = ( cache_operator (L. ops[i], L . cache [i]), ops ... )
629638 end
630639
631- L
640+ @set! L . ops = ops
632641end
633642
634643function LinearAlgebra. mul! (v:: AbstractVecOrMat , L:: ComposedOperator , u:: AbstractVecOrMat )
0 commit comments