480480TensorProductOperator (op:: AbstractSciMLOperator ) = op
481481TensorProductOperator (op:: AbstractMatrix ) = MatrixOperator (op)
482482TensorProductOperator (ops... ) = reduce (TensorProductOperator, ops)
483+ TensorProductOperator (Io:: IdentityOperator{No} , Ii:: IdentityOperator{Ni} ) where {No,Ni} = IdentityOperator {No*Ni} ()
483484
484485# overload ⊗ (\otimes)
485486⊗ (ops:: Union{AbstractMatrix,AbstractSciMLOperator} ...) = TensorProductOperator (ops... )
@@ -540,13 +541,18 @@ for op in (
540541 C = $ op (L. inner, U)
541542
542543 V = if k > 1
543- C = _reshape (C, (mi, no, k))
544- C = permutedims (C, perm)
545- C = _reshape (C, (no, mi* k))
546-
547- V = $ op (L. outer, C)
548- V = _reshape (V, (mo, mi, k))
549- V = permutedims (V, perm)
544+ V = if L. outer isa IdentityOperator
545+ copy (C)
546+ else
547+ C = _reshape (C, (mi, no, k))
548+ C = permutedims (C, perm)
549+ C = _reshape (C, (no, mi* k))
550+
551+ V = $ op (L. outer, C)
552+ V = _reshape (V, (mo, mi, k))
553+ V = permutedims (V, perm)
554+ V
555+ end
550556
551557 V
552558 else
@@ -565,7 +571,7 @@ function cache_self(L::TensorProductOperator, u::AbstractVecOrMat)
565571 c1 = similar (u, (mi, no* k)) # c1 = L.inner * u
566572 c2 = similar (u, (no, mi, k)) # permut (2, 1, 3)
567573 c3 = similar (u, (mo, mi* k)) # c3 = L.outer * c2
568- c4 = similar (u, (mo* mi, k)) # 5 arg mul!
574+ c4 = similar (u, (mo* mi, k)) # cache v in 5 arg mul!
569575
570576 @set! L. cache = (c1, c2, c3, c4,)
571577 L
@@ -610,14 +616,17 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
610616
611617 # V .= U * B' <===> V' .= B * C'
612618 if k> 1
613- # TODO - avoid ops if L.outer is IdentityOperator
614- C1 = _reshape (C1, (mi, no, k))
615- permutedims! (C2, C1, perm)
616- C2 = _reshape (C2, (no, mi* k))
617- mul! (C3, L. outer, C2)
618- C3 = _reshape (C3, (mo, mi, k))
619- V = _reshape (v , (mi, mo, k))
620- permutedims! (V, C3, perm)
619+ if L. outer isa IdentityOperator
620+ copyto! (v, C1)
621+ else
622+ C1 = _reshape (C1, (mi, no, k))
623+ permutedims! (C2, C1, perm)
624+ C2 = _reshape (C2, (no, mi* k))
625+ mul! (C3, L. outer, C2)
626+ C3 = _reshape (C3, (mo, mi, k))
627+ V = _reshape (v , (mi, mo, k))
628+ permutedims! (V, C3, perm)
629+ end
621630 else
622631 V = _reshape (v, (mi, mo))
623632 C1 = _reshape (C1, (mi, no))
@@ -649,15 +658,20 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
649658
650659 # V = α(C * B') + β(V)
651660 if k> 1
652- C1 = _reshape (C1, (mi, no, k))
653- permutedims! (C2, C1, perm)
654- C2 = _reshape (C2, (no, mi* k))
655- mul! (C3, L. outer, C2)
656- C3 = _reshape (C3, (mo, mi, k))
657- V = _reshape (v , (mi, mo, k))
658- copy! (c4, v)
659- permutedims! (V, C3, perm)
660- axpby! (β, c4, α, v)
661+ if L. outer isa IdentityOperator
662+ c1 = _reshape (C1, (m, k))
663+ axpby! (α, c1, β, v)
664+ else
665+ C1 = _reshape (C1, (mi, no, k))
666+ permutedims! (C2, C1, perm)
667+ C2 = _reshape (C2, (no, mi* k))
668+ mul! (C3, L. outer, C2)
669+ C3 = _reshape (C3, (mo, mi, k))
670+ V = _reshape (v , (mi, mo, k))
671+ copy! (c4, v)
672+ permutedims! (V, C3, perm)
673+ axpby! (β, c4, α, v)
674+ end
661675 else
662676 V = _reshape (v , (mi, mo))
663677 C1 = _reshape (C1, (mi, no))
@@ -689,13 +703,17 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
689703
690704 # V .= C / B' <===> V' .= B \ C'
691705 if k> 1
692- C1 = _reshape (C1, (mi, no, k))
693- permutedims! (C2, C1, perm)
694- C2 = _reshape (C2, (no, mi* k))
695- ldiv! (C3, L. outer, C2)
696- C3 = _reshape (C3, (mo, mi, k))
697- V = _reshape (v , (mi, mo, k))
698- permutedims! (V, C3, perm)
706+ if L. outer isa IdentityOperator
707+ copyto! (v, C1)
708+ else
709+ C1 = _reshape (C1, (mi, no, k))
710+ permutedims! (C2, C1, perm)
711+ C2 = _reshape (C2, (no, mi* k))
712+ ldiv! (C3, L. outer, C2)
713+ C3 = _reshape (C3, (mo, mi, k))
714+ V = _reshape (v , (mi, mo, k))
715+ permutedims! (V, C3, perm)
716+ end
699717 else
700718 V = _reshape (v , (mi, mo))
701719 C1 = _reshape (C1, (mi, no))
@@ -726,7 +744,7 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
726744 ldiv! (L. inner, U)
727745
728746 # U .= U / B' <===> U' .= B \ U'
729- if k> 1
747+ if k> 1 & ! (L . outer isa IdentityOperator)
730748 U = _reshape (U, (ni, no, k))
731749 C = _reshape (C, (no, ni, k))
732750 permutedims! (C, U, perm)
0 commit comments