Skip to content

Commit 2eaede7

Browse files
committed
short circuit when identity. will speed things up with kronsum
1 parent a83d75a commit 2eaede7

File tree

1 file changed

+51
-33
lines changed

1 file changed

+51
-33
lines changed

src/sciml.jl

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ end
480480
# Single-argument case: an `AbstractSciMLOperator` is returned as-is, with no
# wrapping.
function TensorProductOperator(op::AbstractSciMLOperator)
    return op
end
481481
# Single-argument case for a plain matrix: wrap it in a `MatrixOperator`.
function TensorProductOperator(op::AbstractMatrix)
    return MatrixOperator(op)
end
482482
# Variadic form: combine all arguments by reducing over them with
# `TensorProductOperator` itself as the binary operation.
function TensorProductOperator(ops...)
    return reduce(TensorProductOperator, ops)
end
483+
# Short circuit for two identity operators: the result is a single
# `IdentityOperator` whose type parameter is the product `No * Ni`.
function TensorProductOperator(
    Io::IdentityOperator{No},
    Ii::IdentityOperator{Ni},
) where {No, Ni}
    return IdentityOperator{No * Ni}()
end
483484

484485
# overload ⊗ (\otimes)
485486
# `⊗` (typed as \otimes) is shorthand for `TensorProductOperator`: it accepts
# any mix of matrices and SciML operators and forwards them unchanged.
⊗(ops::Union{AbstractMatrix,AbstractSciMLOperator}...) = TensorProductOperator(ops...)
@@ -540,13 +541,18 @@ for op in (
540541
C = $op(L.inner, U)
541542

542543
V = if k > 1
543-
C = _reshape(C, (mi, no, k))
544-
C = permutedims(C, perm)
545-
C = _reshape(C, (no, mi*k))
546-
547-
V = $op(L.outer, C)
548-
V = _reshape(V, (mo, mi, k))
549-
V = permutedims(V, perm)
544+
V = if L.outer isa IdentityOperator
545+
copy(C)
546+
else
547+
C = _reshape(C, (mi, no, k))
548+
C = permutedims(C, perm)
549+
C = _reshape(C, (no, mi*k))
550+
551+
V = $op(L.outer, C)
552+
V = _reshape(V, (mo, mi, k))
553+
V = permutedims(V, perm)
554+
V
555+
end
550556

551557
V
552558
else
@@ -565,7 +571,7 @@ function cache_self(L::TensorProductOperator, u::AbstractVecOrMat)
565571
c1 = similar(u, (mi, no*k)) # c1 = L.inner * u
566572
c2 = similar(u, (no, mi, k)) # permut (2, 1, 3)
567573
c3 = similar(u, (mo, mi*k)) # c3 = L.outer * c2
568-
c4 = similar(u, (mo*mi, k)) # 5 arg mul!
574+
c4 = similar(u, (mo*mi, k)) # cache v in 5 arg mul!
569575

570576
@set! L.cache = (c1, c2, c3, c4,)
571577
L
@@ -610,14 +616,17 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
610616

611617
# V .= C * B' <===> V' .= B * C'
612618
if k>1
613-
# TODO - avoid ops if L.outer is IdentityOperator
614-
C1 = _reshape(C1, (mi, no, k))
615-
permutedims!(C2, C1, perm)
616-
C2 = _reshape(C2, (no, mi*k))
617-
mul!(C3, L.outer, C2)
618-
C3 = _reshape(C3, (mo, mi, k))
619-
V = _reshape(v , (mi, mo, k))
620-
permutedims!(V, C3, perm)
619+
if L.outer isa IdentityOperator
620+
copyto!(v, C1)
621+
else
622+
C1 = _reshape(C1, (mi, no, k))
623+
permutedims!(C2, C1, perm)
624+
C2 = _reshape(C2, (no, mi*k))
625+
mul!(C3, L.outer, C2)
626+
C3 = _reshape(C3, (mo, mi, k))
627+
V = _reshape(v , (mi, mo, k))
628+
permutedims!(V, C3, perm)
629+
end
621630
else
622631
V = _reshape(v, (mi, mo))
623632
C1 = _reshape(C1, (mi, no))
@@ -649,15 +658,20 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
649658

650659
# V = α(C * B') + β(V)
651660
if k>1
652-
C1 = _reshape(C1, (mi, no, k))
653-
permutedims!(C2, C1, perm)
654-
C2 = _reshape(C2, (no, mi*k))
655-
mul!(C3, L.outer, C2)
656-
C3 = _reshape(C3, (mo, mi, k))
657-
V = _reshape(v , (mi, mo, k))
658-
copy!(c4, v)
659-
permutedims!(V, C3, perm)
660-
axpby!(β, c4, α, v)
661+
if L.outer isa IdentityOperator
662+
c1 = _reshape(C1, (m, k))
663+
axpby!(α, c1, β, v)
664+
else
665+
C1 = _reshape(C1, (mi, no, k))
666+
permutedims!(C2, C1, perm)
667+
C2 = _reshape(C2, (no, mi*k))
668+
mul!(C3, L.outer, C2)
669+
C3 = _reshape(C3, (mo, mi, k))
670+
V = _reshape(v , (mi, mo, k))
671+
copy!(c4, v)
672+
permutedims!(V, C3, perm)
673+
axpby!(β, c4, α, v)
674+
end
661675
else
662676
V = _reshape(v , (mi, mo))
663677
C1 = _reshape(C1, (mi, no))
@@ -689,13 +703,17 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
689703

690704
# V .= C / B' <===> V' .= B \ C'
691705
if k>1
692-
C1 = _reshape(C1, (mi, no, k))
693-
permutedims!(C2, C1, perm)
694-
C2 = _reshape(C2, (no, mi*k))
695-
ldiv!(C3, L.outer, C2)
696-
C3 = _reshape(C3, (mo, mi, k))
697-
V = _reshape(v , (mi, mo, k))
698-
permutedims!(V, C3, perm)
706+
if L.outer isa IdentityOperator
707+
copyto!(v, C1)
708+
else
709+
C1 = _reshape(C1, (mi, no, k))
710+
permutedims!(C2, C1, perm)
711+
C2 = _reshape(C2, (no, mi*k))
712+
ldiv!(C3, L.outer, C2)
713+
C3 = _reshape(C3, (mo, mi, k))
714+
V = _reshape(v , (mi, mo, k))
715+
permutedims!(V, C3, perm)
716+
end
699717
else
700718
V = _reshape(v , (mi, mo))
701719
C1 = _reshape(C1, (mi, no))
@@ -726,7 +744,7 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
726744
ldiv!(L.inner, U)
727745

728746
# U .= U / B' <===> U' .= B \ U'
729-
if k>1
747+
if k>1 & !(L.outer isa IdentityOperator)
730748
U = _reshape(U, (ni, no, k))
731749
C = _reshape(C, (no, ni, k))
732750
permutedims!(C, U, perm)

0 commit comments

Comments
 (0)