Skip to content

Commit c948888

Browse files
Merge pull request #59 from vpuri3/permutedims
Faster TensorProductOperator
2 parents 746c0d5 + 2eaede7 commit c948888

File tree

4 files changed

+89
-63
lines changed

4 files changed

+89
-63
lines changed

src/basic.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ function Base.zero(A::AbstractSciMLOperator)
102102
NullOperator{N}()
103103
end
104104

105-
# TODO sparse diagonal
106105
Base.convert(::Type{AbstractMatrix}, ::NullOperator{N}) where{N} = Diagonal(zeros(Bool, N))
107106

108107
# traits
@@ -200,8 +199,8 @@ function Base.adjoint(α::ScalarOperator) # TODO - test
200199
ScalarOperator(val; update_func=update_func)
201200
end
202201
Base.transpose::ScalarOperator) = α
203-
Base.one(::Type{AbstractSciMLOperator}) = ScalarOperator(true)
204-
Base.zero(::Type{AbstractSciMLOperator}) = ScalarOperator(false)
202+
Base.one(::Type{<:AbstractSciMLOperator}) = ScalarOperator(true)
203+
Base.zero(::Type{<:AbstractSciMLOperator}) = ScalarOperator(false)
205204

206205
getops::ScalarOperator) =.val,)
207206
islinear(L::ScalarOperator) = true
@@ -230,7 +229,7 @@ end
230229
for op in (:-, :+)
231230
@eval Base.$op::ScalarOperator, x::Number) = $op.val, x)
232231
@eval Base.$op(x::Number, α::ScalarOperator) = $op(x, α.val)
233-
@eval Base.$op(x::ScalarOperator, y::ScalarOperator) = $op(x.val, y.val) # TODO - lazy compose instead?
232+
@eval Base.$op(x::ScalarOperator, y::ScalarOperator) = $op(x.val, y.val) # TODO - lazy sum instead?
234233
end
235234

236235
LinearAlgebra.lmul!::ScalarOperator, u::AbstractVecOrMat) = lmul!.val, u)

src/sciml.jl

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ end
480480
TensorProductOperator(op::AbstractSciMLOperator) = op
481481
TensorProductOperator(op::AbstractMatrix) = MatrixOperator(op)
482482
TensorProductOperator(ops...) = reduce(TensorProductOperator, ops)
483+
TensorProductOperator(Io::IdentityOperator{No}, Ii::IdentityOperator{Ni}) where{No,Ni} = IdentityOperator{No*Ni}()
483484

484485
# overload ⊗ (\otimes)
485486
(ops::Union{AbstractMatrix,AbstractSciMLOperator}...) = TensorProductOperator(ops...)
@@ -534,15 +535,23 @@ for op in (
534535
m , n = size(L)
535536
k = size(u, 2)
536537

538+
perm = (2, 1, 3)
539+
537540
U = _reshape(u, (ni, no*k))
538541
C = $op(L.inner, U)
539542

540543
V = if k > 1
541-
C = _reshape(C, (mi, no, k))
542-
V = similar( u, (mi, mo, k))
543-
544-
@views for i=1:k
545-
V[:,:,i] = transpose($op(L.outer, transpose(C[:,:,i])))
544+
V = if L.outer isa IdentityOperator
545+
copy(C)
546+
else
547+
C = _reshape(C, (mi, no, k))
548+
C = permutedims(C, perm)
549+
C = _reshape(C, (no, mi*k))
550+
551+
V = $op(L.outer, C)
552+
V = _reshape(V, (mo, mi, k))
553+
V = permutedims(V, perm)
554+
V
546555
end
547556

548557
V
@@ -556,10 +565,15 @@ end
556565

557566
function cache_self(L::TensorProductOperator, u::AbstractVecOrMat)
558567
mi, _ = size(L.inner)
559-
_ , no = size(L.outer)
568+
mo, no = size(L.outer)
560569
k = size(u, 2)
561570

562-
@set! L.cache = similar(u, (mi, no*k))
571+
c1 = similar(u, (mi, no*k)) # c1 = L.inner * u
572+
c2 = similar(u, (no, mi, k)) # permut (2, 1, 3)
573+
c3 = similar(u, (mo, mi*k)) # c3 = L.outer * c2
574+
c4 = similar(u, (mo*mi, k)) # cache v in 5 arg mul!
575+
576+
@set! L.cache = (c1, c2, c3, c4,)
563577
L
564578
end
565579

@@ -573,8 +587,7 @@ function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat) where{D}
573587
k = size(u, 2)
574588

575589
uinner = _reshape(u, (ni, no*k))
576-
uouter = _reshape(L.cache, (no, mi*k))
577-
uouter = @views uouter[:,1:mi]
590+
uouter = L.cache[2]
578591

579592
@set! L.inner = cache_operator(L.inner, uinner)
580593
@set! L.outer = cache_operator(L.outer, uouter)
@@ -589,7 +602,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
589602
mo, no = size(L.outer)
590603
k = size(u, 2)
591604

592-
C = L.cache
605+
perm = (2, 1, 3)
606+
C1, C2, C3, _ = L.cache
593607
U = _reshape(u, (ni, no*k))
594608

595609
"""
@@ -598,20 +612,25 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
598612
"""
599613

600614
# C .= A * U
601-
mul!(C, L.inner, U)
615+
mul!(C1, L.inner, U)
602616

603617
# V .= U * B' <===> V' .= B * C'
604618
if k>1
605-
V = _reshape(v, (mi, mo, k))
606-
C = _reshape(C, (mi, no, k))
607-
608-
@views for i=1:k
609-
mul!(transpose(V[:,:,i]), L.outer, transpose(C[:,:,i]))
619+
if L.outer isa IdentityOperator
620+
copyto!(v, C1)
621+
else
622+
C1 = _reshape(C1, (mi, no, k))
623+
permutedims!(C2, C1, perm)
624+
C2 = _reshape(C2, (no, mi*k))
625+
mul!(C3, L.outer, C2)
626+
C3 = _reshape(C3, (mo, mi, k))
627+
V = _reshape(v , (mi, mo, k))
628+
permutedims!(V, C3, perm)
610629
end
611630
else
612-
V = _reshape(v, (mi, mo))
613-
C = _reshape(C, (mi, no))
614-
mul!(transpose(V), L.outer, transpose(C))
631+
V = _reshape(v, (mi, mo))
632+
C1 = _reshape(C1, (mi, no))
633+
mul!(transpose(V), L.outer, transpose(C1))
615634
end
616635

617636
v
@@ -625,7 +644,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
625644
mo, no = size(L.outer)
626645
k = size(u, 2)
627646

628-
C = L.cache
647+
perm = (2, 1, 3)
648+
C1, C2, C3, c4 = L.cache
629649
U = _reshape(u, (ni, no*k))
630650

631651
"""
@@ -634,19 +654,27 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::Ab
634654
"""
635655

636656
# C .= A * U
637-
mul!(C, L.inner, U)
657+
mul!(C1, L.inner, U)
638658

639659
# V = α(C * B') + β(V)
640660
if k>1
641-
V = _reshape(v, (mi, mo, k))
642-
C = _reshape(C, (mi, no, k))
643-
644-
@views for i=1:k
645-
mul!(transpose(V[:,:,i]), L.outer, transpose(C[:,:,i]), α, β)
661+
if L.outer isa IdentityOperator
662+
c1 = _reshape(C1, (m, k))
663+
axpby!(α, c1, β, v)
664+
else
665+
C1 = _reshape(C1, (mi, no, k))
666+
permutedims!(C2, C1, perm)
667+
C2 = _reshape(C2, (no, mi*k))
668+
mul!(C3, L.outer, C2)
669+
C3 = _reshape(C3, (mo, mi, k))
670+
V = _reshape(v , (mi, mo, k))
671+
copy!(c4, v)
672+
permutedims!(V, C3, perm)
673+
axpby!(β, c4, α, v)
646674
end
647675
else
648-
V = _reshape(v, (mi, mo))
649-
C = _reshape(C, (mi, no))
676+
V = _reshape(v , (mi, mo))
677+
C1 = _reshape(C1, (mi, no))
650678
mul!(transpose(V), L.outer, transpose(C), α, β)
651679
end
652680

@@ -661,7 +689,8 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
661689
mo, no = size(L.outer)
662690
k = size(u, 2)
663691

664-
C = L.cache
692+
perm = (2, 1, 3)
693+
C1, C2, C3, _ = L.cache
665694
U = _reshape(u, (ni, no*k))
666695

667696
"""
@@ -670,20 +699,25 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::TensorProductOperator, u::A
670699
"""
671700

672701
# C .= A \ U
673-
ldiv!(C, L.inner, U)
702+
ldiv!(C1, L.inner, U)
674703

675704
# V .= C / B' <===> V' .= B \ C'
676705
if k>1
677-
C = _reshape(C, (mi, no, k))
678-
V = _reshape(v, (mi, mo, k))
679-
680-
@views for i=1:k
681-
ldiv!(transpose(V[:,:,i]), L.outer, transpose(C[:,:,i]))
706+
if L.outer isa IdentityOperator
707+
copyto!(v, C1)
708+
else
709+
C1 = _reshape(C1, (mi, no, k))
710+
permutedims!(C2, C1, perm)
711+
C2 = _reshape(C2, (no, mi*k))
712+
ldiv!(C3, L.outer, C2)
713+
C3 = _reshape(C3, (mo, mi, k))
714+
V = _reshape(v , (mi, mo, k))
715+
permutedims!(V, C3, perm)
682716
end
683717
else
684-
V = _reshape(v, (mi, mo))
685-
C = _reshape(C, (mi, no))
686-
ldiv!(transpose(V), L.outer, transpose(C))
718+
V = _reshape(v , (mi, mo))
719+
C1 = _reshape(C1, (mi, no))
720+
ldiv!(transpose(V), L.outer, transpose(C1))
687721
end
688722

689723
v
@@ -693,10 +727,12 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
693727
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
694728
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
695729

696-
mi, ni = size(L.inner)
697-
_ , no = size(L.outer)
698-
k = size(u, 2)
730+
ni = size(L.inner, 1)
731+
no = size(L.outer, 1)
732+
k = size(u, 2)
699733

734+
perm = (2, 1, 3)
735+
C = L.cache[1]
700736
U = _reshape(u, (ni, no*k))
701737

702738
"""
@@ -708,12 +744,12 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat)
708744
ldiv!(L.inner, U)
709745

710746
# U .= U / B' <===> U' .= B \ U'
711-
if k>1
712-
U = _reshape(U, (mi, no, k))
713-
714-
@views for i=1:k
715-
ldiv!(L.outer, transpose(U[:,:,i]))
716-
end
747+
if k>1 & !(L.outer isa IdentityOperator)
748+
U = _reshape(U, (ni, no, k))
749+
C = _reshape(C, (no, ni, k))
750+
permutedims!(C, U, perm)
751+
ldiv!(L.outer, C)
752+
permutedims!(U, C, perm)
717753
else
718754
ldiv!(L.outer, transpose(U))
719755
end

src/utils.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@ _vec(a::AbstractVector) = a
1212
_vec(a::AbstractArray) = _reshape(a,(length(a),))
1313
_vec(a::ReshapedArray) = _vec(a.parent)
1414

15-
function _view(a, dims::NTuple{D,Int}) where{D}
16-
# just one Colon -> _vec
17-
all(dim -> isa(dim, Colon), dims) && return a
18-
dims == size(a) && return a
19-
length(a) == prod(dims) && return a
20-
21-
view(a, dims...)
22-
end
23-
2415
function _mat_sizes(L::AbstractSciMLOperator, u::AbstractArray)
2516

2617
size_in = u isa AbstractVecOrMat ? size(u) : begin

test/sciml.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SciMLOperators, LinearAlgebra
22
using Random
33

4-
using SciMLOperators: AbstractSciMLOperator, InvertibleOperator,
4+
using SciMLOperators: InvertibleOperator,
55

66
Random.seed!(0)
77
N = 8
@@ -214,8 +214,8 @@ end
214214
D1 = DiagonalOperator(rand(N2))
215215
D2 = DiagonalOperator(rand(N2))
216216

217-
TT = AbstractSciMLOperator[T1, T2]
218-
DD = Diagonal(AbstractSciMLOperator[D1, D2])
217+
TT = [T1, T2]
218+
DD = Diagonal([D1, D2])
219219

220220
op = TT' * DD * TT
221221
op = cache_operator(op, u)

0 commit comments

Comments
 (0)