Skip to content

Commit c606a1f

Browse files
authored
Fast KronTrav * DiagTrav (#138)
1 parent ef7785d commit c606a1f

File tree

5 files changed

+118
-29
lines changed

5 files changed

+118
-29
lines changed

ext/LazyBandedMatricesInfiniteArraysExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using LazyBandedMatrices.BlockArrays
44
using LazyBandedMatrices.ArrayLayouts
55

66
import Base: BroadcastStyle, copy, OneTo, oneto
7-
import LazyBandedMatrices: _krontrav_axes, _block_interlace_axes, _broadcast_sub_arguments, AbstractLazyBandedBlockBandedLayout, KronTravBandedBlockBandedLayout, krontravargs
7+
import LazyBandedMatrices: _krontrav_axes, _block_interlace_axes, _broadcast_sub_arguments, AbstractLazyBandedBlockBandedLayout, KronTravBandedBlockBandedLayout, krontravargs, DiagTravLayout
88
import InfiniteArrays: InfFill, TridiagonalToeplitzLayout, BidiagonalToeplitzLayout, LazyArrayStyle, OneToInf
99
import LazyBandedMatrices.ArrayLayouts: MemoryLayout, sublayout, RangeCumsum, Mul
1010
import LazyBandedMatrices.BlockArrays: sizes_from_blocks, BlockedOneTo, BlockSlice1, BlockSlice
11-
import LazyBandedMatrices.LazyArrays: BroadcastBandedLayout
11+
import LazyBandedMatrices.LazyArrays: BroadcastBandedLayout, AbstractPaddedLayout
1212

1313
const OneToInfCumsum = RangeCumsum{Int,OneToInf{Int}}
1414

@@ -53,4 +53,8 @@ _block_interlace_axes(nbc::Int, ax::NTuple{2,BlockedOneTo{Int,OneToInf{Int}}}...
5353
(blockedrange(Fill(length(ax) ÷ nbc, ∞)),blockedrange(Fill(mod1(length(ax),nbc), ∞)))
5454

5555

56+
# KronTrav * DiagTrav
57+
58+
copy(M::Mul{InfKronTravBandedBlockBandedLayout, Lay}) where Lay<:DiagTravLayout{<:AbstractPaddedLayout} = copy(Mul{KronTravBandedBlockBandedLayout, Lay}(M.A, M.B))
59+
5660
end

src/LazyBandedMatrices.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ import LinearAlgebra: transpose, adjoint, istriu, istril, isdiag, tril!, triu!,
1313

1414
import ArrayLayouts: MemoryLayout, bidiagonallayout, bidiagonaluplo, diagonaldata, supdiagonaldata, subdiagonaldata,
1515
symtridiagonallayout, tridiagonallayout, symmetriclayout,
16-
colsupport, rowsupport, sublayout, sub_materialize, _copyto!
16+
colsupport, rowsupport, sublayout, sub_materialize, _copyto!,
17+
materialize!, MulAdd, MatMulVecAdd
1718
import LazyArrays: ApplyLayout, AbstractPaddedLayout, PaddedLayout, PaddedColumns, BroadcastLayout, LazyArrayStyle, LazyLayout,
1819
arguments, call, tuple_type_memorylayouts, paddeddata, _broadcast_sub_arguments, resizedata!,
1920
_cumsum, convexunion, applylayout, AbstractLazyBandedLayout, ApplyBandedLayout, BroadcastBandedLayout, LazyBandedLayout

src/blockkron.jl

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,67 @@
44
# Block
55
###
66

7+
function _DiagTrav end
8+
79
"""
810
DiagTrav(A::AbstractMatrix)
911
1012
converts a matrix to a block vector by traversing the anti-diagonals.
1113
"""
1214
struct DiagTrav{T, N, AA<:AbstractArray{T,N}} <: AbstractBlockVector{T}
1315
array::AA
14-
function DiagTrav{T, N, AA}(array::AA) where {T, N, AA<:AbstractArray{T,N}}
16+
global function _DiagTrav(array::AA) where {T, N, AA<:AbstractArray{T,N}}
1517
new{T,N,AA}(array)
1618
end
1719
end
20+
DiagTrav{T,N,AA}(A::AA) where {T, N, AA<:AbstractArray{T,N}} = _DiagTrav(zero_bottomright(A))
21+
1822
DiagTrav{T,N}(A::AbstractArray) where {T,N} = DiagTrav{T,N,typeof(A)}(A)
1923
DiagTrav{T}(A::AbstractArray{<:Any,N}) where {T,N} = DiagTrav{T,N}(A)
2024
DiagTrav(A::AbstractArray{T}) where T = DiagTrav{T}(A)
2125

26+
27+
zero_bottomright(array) = zero_bottomright(array, axes(array))
28+
zero_bottomright!(array) = zero_bottomright!(array, axes(array))
29+
30+
function zero_bottomright(X::AbstractMatrix, _)
31+
m,n = size(X)
32+
μ = max(m,n)
33+
for j in rowsupport(X), k =-j+2:m) colsupport(X,j)
34+
iszero(X[k,j]) || return zero_bottomright!(copy(X))
35+
end
36+
X
37+
end
38+
39+
function zero_bottomright!(X::AbstractMatrix{T}, _) where T
40+
m,n = size(X)
41+
μ = max(m,n)
42+
for j in rowsupport(X), k =-j+2:m) colsupport(X,j)
43+
X[k,j] = zero(T)
44+
end
45+
X
46+
end
47+
48+
function zero_bottomright(X::AbstractArray{<:Any,3}, _)
49+
m,n,p = size(X)
50+
@assert m == n == p
51+
for= 0:n-1, j=0:n-1, k=max(0,n-(ℓ+j)):n-1
52+
iszero(X[k+1,j+1,ℓ+1]) || return zero_bottomright!(copy(X))
53+
end
54+
X
55+
end
56+
57+
function zero_bottomright!(X::AbstractArray{T,3}, _) where T
58+
m,n,p = size(X)
59+
@assert m == n == p
60+
for= 0:n-1, j=0:n-1, k=max(0,n-(ℓ+j)):n-1
61+
X[k+1,j+1,ℓ+1] = zero(T)
62+
end
63+
X
64+
end
65+
66+
67+
2268
function _krontrav_axes(A, B)
2369
m,n = length(A), length(B)
2470
mn = min(m,n)
@@ -35,6 +81,8 @@ axes(A::DiagTrav) = (_krontrav_axes(axes(A.array)...),)
3581

3682
copy(A::DiagTrav) = DiagTrav(copy(A.array))
3783

84+
similar(A::DiagTrav, ::Type{T}) where T = DiagTrav(similar(A.array, T))
85+
3886
struct DiagTravLayout{Lay} <: AbstractBlockLayout end
3987
MemoryLayout(::Type{<:DiagTrav{T, N, AA}}) where {T,N,AA} = DiagTravLayout{typeof(MemoryLayout(AA))}()
4088

@@ -62,7 +110,7 @@ function _diagtravgetindex(_, A::AbstractMatrix, K::Block{1})
62110
end
63111

64112

65-
_diagtravgetindex(::AbstractStridedLayout, A::AbstractMatrix, K::Block{1}) = layout_getindex(DiagTrav(A), K)
113+
_diagtravgetindex(::AbstractStridedLayout, A::AbstractMatrix, K::Block{1}) = layout_getindex(_DiagTrav(A), K)
66114

67115
function _diagtravview(::AbstractStridedLayout, A::AbstractMatrix, K::Block{1})
68116
k = Int(K)
@@ -110,6 +158,7 @@ function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Bl
110158
end
111159

112160
getindex(A::DiagTrav, k::Int) = A[findblockindex(axes(A,1), k)]
161+
setindex!(A::DiagTrav, v, k::Int) = A[findblockindex(axes(A,1), k)] = v
113162

114163
function resize!(A::DiagTrav{<:Any,2}, K::Block{1})
115164
k = Int(K)
@@ -261,6 +310,7 @@ convert(::Type{B}, A::KronTrav{<:Any,2}) where B<:BandedBlockBandedMatrix = conv
261310
struct KronTravBandedBlockBandedLayout <: AbstractBandedBlockBandedLayout end
262311
struct KronTravLayout{M} <: AbstractBlockLayout end
263312

313+
const KronTravLayouts = Union{KronTravBandedBlockBandedLayout, KronTravLayout}
264314

265315

266316
krontravlayout(::Vararg{Any,M}) where M = KronTravLayout{M}()
@@ -320,4 +370,27 @@ BroadcastStyle(::Type{KronTrav{T,N,AA,AXIS}}) where {T,N,AA,AXIS} =
320370
###
321371

322372
*(a::Number, b::KronTrav) = KronTrav(a*first(b.args), tail(b.args)...)
323-
*(a::KronTrav, b::Number) = KronTrav(first(a.args)*b, tail(a.args)...)
373+
*(a::KronTrav, b::Number) = KronTrav(first(a.args)*b, tail(a.args)...)
374+
375+
376+
function copy(M::Mul{<:KronTravLayouts, <:DiagTravLayout})
377+
K,x = M.A,M.B
378+
A,B = K.args
379+
_krontrav_mul_diagtrav(K.args, invdiagtrav(x), eltype(M))
380+
end
381+
382+
_krontrav_mul_diagtrav((A,B), X::AbstractMatrix, ::Type{T}) where T = DiagTrav(convert(AbstractMatrix{T}, B*X*A'))
383+
function _krontrav_mul_diagtrav((A,B,C), X::AbstractArray{<:Any,3}, ::Type{T}) where T
384+
m,n,p = size(X)
385+
@assert m == n == p
386+
Y = similar(X, T)
387+
Z = similar(X, T)
388+
for k = 1:n, j=1:n mul!(view(Y,k,j,:),A,view(X,k,j,:)) end
389+
for k = 1:n, j=1:n mul!(view(Z,k,:,j),B,view(Y,k,:,j)) end
390+
for k = 1:n, j=1:n mul!(view(Y,:,k,j),C,view(Z,:,k,j)) end
391+
DiagTrav(Y)
392+
end
393+
394+
395+
396+
# C = α*B*X*A' + β*C

test/test_blockkron.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
6868

6969
@testset "InvDiagTrav" begin
7070
A = [1 2 3; 4 5 6; 7 8 9]
71-
@test invdiagtrav(BlockedVector(DiagTrav(A))) == [1 2 3; 4 5 0; 7 0 0]
72-
@test invdiagtrav(DiagTrav(A)) == A
71+
@test invdiagtrav(BlockedVector(DiagTrav(A))) == invdiagtrav(DiagTrav(A)) == [1 2 3; 4 5 0; 7 0 0]
7372
end
7473

7574
@testset "BlockKron" begin
@@ -124,7 +123,9 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
124123
@test copy(K) == K
125124

126125
X = [9 10; 11 0]
127-
@test K*DiagTrav(X) == DiagTrav(B*X*A')
126+
Y = [9 10; 11 12]
127+
@test K*DiagTrav(X) == K*DiagTrav(Y) == DiagTrav(B*X*A')
128+
@test K*DiagTrav(X) isa DiagTrav
128129

129130
@test K[Block.(Base.OneTo(2)), Block.(Base.OneTo(2))] == K[Block.(1:2), Block.(1:2)] == K
130131
end
@@ -160,9 +161,11 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
160161
K = KronTrav(A,B,C)
161162

162163
X = randn(n,n,n)
164+
Y = copy(X)
163165
for= 0:n-1, j=0:n-1, k=max(0,n-(ℓ+j)):n-1
164166
X[k+1,j+1,ℓ+1] = 0
165167
end
168+
@test DiagTrav(Y).array == X
166169
Y = float(similar(X))
167170
for k = 1:n, j=1:n Y[k,j,:] = A*X[k,j,:] end
168171
for k = 1:n, j=1:n Y[k,:,j] = B*Y[k,:,j] end

test/test_lazybandedinf.jl

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,28 @@ const OneToInfBlocks = InfiniteArraysBlockArraysExt.OneToInfBlocks
1212
const InfKronTravBandedBlockBandedLayout = LazyBandedMatricesInfiniteArraysExt.InfKronTravBandedBlockBandedLayout
1313

1414
@testset "∞ LazyBandedMatrices" begin
15-
@test MemoryLayout(LazyBandedMatrices.Tridiagonal(Fill(1,∞), Zeros(∞), Fill(3,∞))) isa TridiagonalToeplitzLayout
16-
@test MemoryLayout(LazyBandedMatrices.Bidiagonal(Fill(1,∞), Zeros(∞), :U)) isa BidiagonalToeplitzLayout
17-
@test MemoryLayout(LazyBandedMatrices.SymTridiagonal(Fill(1,∞), Zeros(∞))) isa TridiagonalToeplitzLayout
18-
19-
T = LazyBandedMatrices.Tridiagonal(Fill(1,∞), Zeros(∞), Fill(3,∞))
20-
@test T[2:∞,3:∞] isa SubArray
21-
@test exp.(T) isa BroadcastMatrix
22-
@test exp.(T)[2:∞,3:∞][1:10,1:10] == exp.(T[2:∞,3:∞])[1:10,1:10] == exp.(T[2:11,3:12])
23-
@test exp.(T)[2:∞,3:∞] isa BroadcastMatrix
24-
@test exp.(T[2:∞,3:∞]) isa BroadcastMatrix
25-
26-
B = LazyBandedMatrices.Bidiagonal(Fill(1,∞), Zeros(∞), :U)
27-
@test B[2:∞,3:∞] isa SubArray
28-
@test exp.(B) isa BroadcastMatrix
29-
@test exp.(B)[2:∞,3:∞][1:10,1:10] == exp.(B[2:∞,3:∞])[1:10,1:10] == exp.(B[2:11,3:12])
30-
@test exp.(B)[2:∞,3:∞] isa BroadcastMatrix
31-
32-
@testset "Diagonal{Fill} * Bidiagonal" begin
33-
A, B = Diagonal(Fill(2,∞)) , LazyBandedMatrices.Bidiagonal(exp.(1:∞), exp.(1:∞), :L)
34-
@test (A*B)[1:10,1:10] (B*A)[1:10,1:10] 2B[1:10,1:10]
15+
@testset "Tri/Bidiagonal" begin
16+
@test MemoryLayout(LazyBandedMatrices.Tridiagonal(Fill(1,∞), Zeros(∞), Fill(3,∞))) isa TridiagonalToeplitzLayout
17+
@test MemoryLayout(LazyBandedMatrices.Bidiagonal(Fill(1,∞), Zeros(∞), :U)) isa BidiagonalToeplitzLayout
18+
@test MemoryLayout(LazyBandedMatrices.SymTridiagonal(Fill(1,∞), Zeros(∞))) isa TridiagonalToeplitzLayout
19+
20+
T = LazyBandedMatrices.Tridiagonal(Fill(1,∞), Zeros(∞), Fill(3,∞))
21+
@test T[2:∞,3:∞] isa SubArray
22+
@test exp.(T) isa BroadcastMatrix
23+
@test exp.(T)[2:∞,3:∞][1:10,1:10] == exp.(T[2:∞,3:∞])[1:10,1:10] == exp.(T[2:11,3:12])
24+
@test exp.(T)[2:∞,3:∞] isa BroadcastMatrix
25+
@test exp.(T[2:∞,3:∞]) isa BroadcastMatrix
26+
27+
B = LazyBandedMatrices.Bidiagonal(Fill(1,∞), Zeros(∞), :U)
28+
@test B[2:∞,3:∞] isa SubArray
29+
@test exp.(B) isa BroadcastMatrix
30+
@test exp.(B)[2:∞,3:∞][1:10,1:10] == exp.(B[2:∞,3:∞])[1:10,1:10] == exp.(B[2:11,3:12])
31+
@test exp.(B)[2:∞,3:∞] isa BroadcastMatrix
32+
33+
@testset "Diagonal{Fill} * Bidiagonal" begin
34+
A, B = Diagonal(Fill(2,∞)) , LazyBandedMatrices.Bidiagonal(exp.(1:∞), exp.(1:∞), :L)
35+
@test (A*B)[1:10,1:10] (B*A)[1:10,1:10] 2B[1:10,1:10]
36+
end
3537
end
3638

3739
@testset "∞-unit blocks" begin
@@ -102,6 +104,12 @@ const InfKronTravBandedBlockBandedLayout = LazyBandedMatricesInfiniteArraysExt.I
102104

103105
@test A*A isa KronTrav
104106
@test (A*A)[Block.(Base.OneTo(3)), Block.(Base.OneTo(3))] A[Block.(1:3), Block.(1:4)]A[Block.(1:4), Block.(1:3)]
107+
108+
@testset "mul" begin
109+
X = zeros(∞,∞); X[1,1] = 1;
110+
KR = Block.(1:10)
111+
@test (A*DiagTrav(X))[KR] == ((A + 0I) * DiagTrav(X))[KR]
112+
end
105113
end
106114

107115
@testset "BlockHcat copyto!" begin

0 commit comments

Comments
 (0)