480480TensorProductOperator (op:: AbstractSciMLOperator ) = op
481481TensorProductOperator (op:: AbstractMatrix ) = MatrixOperator (op)
482482TensorProductOperator (ops... ) = reduce (TensorProductOperator, ops)
483+ TensorProductOperator (Io:: IdentityOperator{No} , Ii:: IdentityOperator{Ni} ) where {No,Ni} = IdentityOperator {No*Ni} ()
483484
484485# overload ⊗ (\otimes)
485486⊗ (ops:: Union{AbstractMatrix,AbstractSciMLOperator} ...) = TensorProductOperator (ops... )
@@ -534,15 +535,23 @@ for op in (
534535 m , n = size (L)
535536 k = size (u, 2 )
536537
538+ perm = (2 , 1 , 3 )
539+
537540 U = _reshape (u, (ni, no* k))
538541 C = $ op (L. inner, U)
539542
540543 V = if k > 1
541- C = _reshape (C, (mi, no, k))
542- V = similar ( u, (mi, mo, k))
543-
544- @views for i= 1 : k
545- V[:,:,i] = transpose ($ op (L. outer, transpose (C[:,:,i])))
544+ V = if L. outer isa IdentityOperator
545+ copy (C)
546+ else
547+ C = _reshape (C, (mi, no, k))
548+ C = permutedims (C, perm)
549+ C = _reshape (C, (no, mi* k))
550+
551+ V = $ op (L. outer, C)
552+ V = _reshape (V, (mo, mi, k))
553+ V = permutedims (V, perm)
554+ V
546555 end
547556
548557 V
@@ -556,10 +565,15 @@ end
556565
557566function cache_self (L:: TensorProductOperator , u:: AbstractVecOrMat )
558567 mi, _ = size (L. inner)
559- _ , no = size (L. outer)
568+ mo , no = size (L. outer)
560569 k = size (u, 2 )
561570
562- @set! L. cache = similar (u, (mi, no* k))
571+ c1 = similar (u, (mi, no* k)) # c1 = L.inner * u
572+ c2 = similar (u, (no, mi, k)) # permut (2, 1, 3)
573+ c3 = similar (u, (mo, mi* k)) # c3 = L.outer * c2
574+ c4 = similar (u, (mo* mi, k)) # cache v in 5 arg mul!
575+
576+ @set! L. cache = (c1, c2, c3, c4,)
563577 L
564578end
565579
@@ -573,8 +587,7 @@ function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat) where{D}
573587 k = size (u, 2 )
574588
575589 uinner = _reshape (u, (ni, no* k))
576- uouter = _reshape (L. cache, (no, mi* k))
577- uouter = @views uouter[:,1 : mi]
590+ uouter = L. cache[2 ]
578591
579592 @set! L. inner = cache_operator (L. inner, uinner)
580593 @set! L. outer = cache_operator (L. outer, uouter)
@@ -589,7 +602,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
589602 mo, no = size (L. outer)
590603 k = size (u, 2 )
591604
592- C = L. cache
605+ perm = (2 , 1 , 3 )
606+ C1, C2, C3, _ = L. cache
593607 U = _reshape (u, (ni, no* k))
594608
595609 """
@@ -598,20 +612,25 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
598612 """
599613
600614 # C .= A * U
601- mul! (C , L. inner, U)
615+ mul! (C1 , L. inner, U)
602616
603617 # V .= U * B' <===> V' .= B * C'
604618 if k> 1
605- V = _reshape (v, (mi, mo, k))
606- C = _reshape (C, (mi, no, k))
607-
608- @views for i= 1 : k
609- mul! (transpose (V[:,:,i]), L. outer, transpose (C[:,:,i]))
619+ if L. outer isa IdentityOperator
620+ copyto! (v, C1)
621+ else
622+ C1 = _reshape (C1, (mi, no, k))
623+ permutedims! (C2, C1, perm)
624+ C2 = _reshape (C2, (no, mi* k))
625+ mul! (C3, L. outer, C2)
626+ C3 = _reshape (C3, (mo, mi, k))
627+ V = _reshape (v , (mi, mo, k))
628+ permutedims! (V, C3, perm)
610629 end
611630 else
612- V = _reshape (v, (mi, mo))
613- C = _reshape (C , (mi, no))
614- mul! (transpose (V), L. outer, transpose (C ))
631+ V = _reshape (v, (mi, mo))
632+ C1 = _reshape (C1 , (mi, no))
633+ mul! (transpose (V), L. outer, transpose (C1 ))
615634 end
616635
617636 v
@@ -625,7 +644,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
625644 mo, no = size (L. outer)
626645 k = size (u, 2 )
627646
628- C = L. cache
647+ perm = (2 , 1 , 3 )
648+ C1, C2, C3, c4 = L. cache
629649 U = _reshape (u, (ni, no* k))
630650
631651 """
@@ -634,19 +654,27 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
634654 """
635655
636656 # C .= A * U
637- mul! (C , L. inner, U)
657+ mul! (C1 , L. inner, U)
638658
639659 # V = α(C * B') + β(V)
640660 if k> 1
641- V = _reshape (v, (mi, mo, k))
642- C = _reshape (C, (mi, no, k))
643-
644- @views for i= 1 : k
645- mul! (transpose (V[:,:,i]), L. outer, transpose (C[:,:,i]), α, β)
661+ if L. outer isa IdentityOperator
662+ c1 = _reshape (C1, (m, k))
663+ axpby! (α, c1, β, v)
664+ else
665+ C1 = _reshape (C1, (mi, no, k))
666+ permutedims! (C2, C1, perm)
667+ C2 = _reshape (C2, (no, mi* k))
668+ mul! (C3, L. outer, C2)
669+ C3 = _reshape (C3, (mo, mi, k))
670+ V = _reshape (v , (mi, mo, k))
671+ copy! (c4, v)
672+ permutedims! (V, C3, perm)
673+ axpby! (β, c4, α, v)
646674 end
647675 else
648- V = _reshape (v, (mi, mo))
649- C = _reshape (C , (mi, no))
676+ V = _reshape (v , (mi, mo))
677+ C1 = _reshape (C1 , (mi, no))
650678 mul! (transpose (V), L. outer, transpose (C), α, β)
651679 end
652680
@@ -661,7 +689,8 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
661689 mo, no = size (L. outer)
662690 k = size (u, 2 )
663691
664- C = L. cache
692+ perm = (2 , 1 , 3 )
693+ C1, C2, C3, _ = L. cache
665694 U = _reshape (u, (ni, no* k))
666695
667696 """
@@ -670,20 +699,25 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
670699 """
671700
672701 # C .= A \ U
673- ldiv! (C , L. inner, U)
702+ ldiv! (C1 , L. inner, U)
674703
675704 # V .= C / B' <===> V' .= B \ C'
676705 if k> 1
677- C = _reshape (C, (mi, no, k))
678- V = _reshape (v, (mi, mo, k))
679-
680- @views for i= 1 : k
681- ldiv! (transpose (V[:,:,i]), L. outer, transpose (C[:,:,i]))
706+ if L. outer isa IdentityOperator
707+ copyto! (v, C1)
708+ else
709+ C1 = _reshape (C1, (mi, no, k))
710+ permutedims! (C2, C1, perm)
711+ C2 = _reshape (C2, (no, mi* k))
712+ ldiv! (C3, L. outer, C2)
713+ C3 = _reshape (C3, (mo, mi, k))
714+ V = _reshape (v , (mi, mo, k))
715+ permutedims! (V, C3, perm)
682716 end
683717 else
684- V = _reshape (v, (mi, mo))
685- C = _reshape (C , (mi, no))
686- ldiv! (transpose (V), L. outer, transpose (C ))
718+ V = _reshape (v , (mi, mo))
719+ C1 = _reshape (C1 , (mi, no))
720+ ldiv! (transpose (V), L. outer, transpose (C1 ))
687721 end
688722
689723 v
@@ -693,10 +727,12 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
693727 @assert L. isset " cache needs to be set up for operator of type $(typeof (L)) .
694728 set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
695729
696- mi, ni = size (L. inner)
697- _ , no = size (L. outer)
698- k = size (u, 2 )
730+ ni = size (L. inner, 1 )
731+ no = size (L. outer, 1 )
732+ k = size (u, 2 )
699733
734+ perm = (2 , 1 , 3 )
735+ C = L. cache[1 ]
700736 U = _reshape (u, (ni, no* k))
701737
702738 """
@@ -708,12 +744,12 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
708744 ldiv! (L. inner, U)
709745
710746 # U .= U / B' <===> U' .= B \ U'
711- if k> 1
712- U = _reshape (U, (mi , no, k))
713-
714- @views for i = 1 : k
715- ldiv! (L. outer, transpose (U[:,:,i]) )
716- end
747+ if k> 1 & ! (L . outer isa IdentityOperator)
748+ U = _reshape (U, (ni , no, k))
749+ C = _reshape (C, (no, ni, k))
750+ permutedims! (C, U, perm)
751+ ldiv! (L. outer, C )
752+ permutedims! (U, C, perm)
717753 else
718754 ldiv! (L. outer, transpose (U))
719755 end
0 commit comments