44# Block
55# ##
66
7+ function _DiagTrav end
8+
79"""
810 DiagTrav(A::AbstractMatrix)
911
1012converts a matrix to a block vector by traversing the anti-diagonals.
1113"""
1214struct DiagTrav{T, N, AA<: AbstractArray{T,N} } <: AbstractBlockVector{T}
1315 array:: AA
14- function DiagTrav {T, N, AA} (array:: AA ) where {T, N, AA<: AbstractArray{T,N} }
16+ global function _DiagTrav (array:: AA ) where {T, N, AA<: AbstractArray{T,N} }
1517 new {T,N,AA} (array)
1618 end
1719end
20+ DiagTrav {T,N,AA} (A:: AA ) where {T, N, AA<: AbstractArray{T,N} } = _DiagTrav (zero_bottomright (A))
21+
1822DiagTrav {T,N} (A:: AbstractArray ) where {T,N} = DiagTrav {T,N,typeof(A)} (A)
1923DiagTrav {T} (A:: AbstractArray{<:Any,N} ) where {T,N} = DiagTrav {T,N} (A)
2024DiagTrav (A:: AbstractArray{T} ) where T = DiagTrav {T} (A)
2125
26+
27+ zero_bottomright (array) = zero_bottomright (array, axes (array))
28+ zero_bottomright! (array) = zero_bottomright! (array, axes (array))
29+
30+ function zero_bottomright (X:: AbstractMatrix , _)
31+ m,n = size (X)
32+ μ = max (m,n)
33+ for j in rowsupport (X), k = (μ- j+ 2 : m) ∩ colsupport (X,j)
34+ iszero (X[k,j]) || return zero_bottomright! (copy (X))
35+ end
36+ X
37+ end
38+
39+ function zero_bottomright! (X:: AbstractMatrix{T} , _) where T
40+ m,n = size (X)
41+ μ = max (m,n)
42+ for j in rowsupport (X), k = (μ- j+ 2 : m) ∩ colsupport (X,j)
43+ X[k,j] = zero (T)
44+ end
45+ X
46+ end
47+
48+ function zero_bottomright (X:: AbstractArray{<:Any,3} , _)
49+ m,n,p = size (X)
50+ @assert m == n == p
51+ for ℓ = 0 : n- 1 , j= 0 : n- 1 , k= max (0 ,n- (ℓ+ j)): n- 1
52+ iszero (X[k+ 1 ,j+ 1 ,ℓ+ 1 ]) || return zero_bottomright! (copy (X))
53+ end
54+ X
55+ end
56+
57+ function zero_bottomright! (X:: AbstractArray{T,3} , _) where T
58+ m,n,p = size (X)
59+ @assert m == n == p
60+ for ℓ = 0 : n- 1 , j= 0 : n- 1 , k= max (0 ,n- (ℓ+ j)): n- 1
61+ X[k+ 1 ,j+ 1 ,ℓ+ 1 ] = zero (T)
62+ end
63+ X
64+ end
65+
66+
67+
2268function _krontrav_axes (A, B)
2369 m,n = length (A), length (B)
2470 mn = min (m,n)
@@ -35,6 +81,8 @@ axes(A::DiagTrav) = (_krontrav_axes(axes(A.array)...),)
3581
3682copy (A:: DiagTrav ) = DiagTrav (copy (A. array))
3783
84+ similar (A:: DiagTrav , :: Type{T} ) where T = DiagTrav (similar (A. array, T))
85+
3886struct DiagTravLayout{Lay} <: AbstractBlockLayout end
3987MemoryLayout (:: Type{<:DiagTrav{T, N, AA}} ) where {T,N,AA} = DiagTravLayout {typeof(MemoryLayout(AA))} ()
4088
@@ -62,7 +110,7 @@ function _diagtravgetindex(_, A::AbstractMatrix, K::Block{1})
62110end
63111
64112
65- _diagtravgetindex (:: AbstractStridedLayout , A:: AbstractMatrix , K:: Block{1} ) = layout_getindex (DiagTrav (A), K)
113+ _diagtravgetindex (:: AbstractStridedLayout , A:: AbstractMatrix , K:: Block{1} ) = layout_getindex (_DiagTrav (A), K)
66114
67115function _diagtravview (:: AbstractStridedLayout , A:: AbstractMatrix , K:: Block{1} )
68116 k = Int (K)
@@ -110,6 +158,7 @@ function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Bl
110158end
111159
112160getindex (A:: DiagTrav , k:: Int ) = A[findblockindex (axes (A,1 ), k)]
161+ setindex! (A:: DiagTrav , v, k:: Int ) = A[findblockindex (axes (A,1 ), k)] = v
113162
114163function resize! (A:: DiagTrav{<:Any,2} , K:: Block{1} )
115164 k = Int (K)
@@ -261,6 +310,7 @@ convert(::Type{B}, A::KronTrav{<:Any,2}) where B<:BandedBlockBandedMatrix = conv
261310struct KronTravBandedBlockBandedLayout <: AbstractBandedBlockBandedLayout end
262311struct KronTravLayout{M} <: AbstractBlockLayout end
263312
313+ const KronTravLayouts = Union{KronTravBandedBlockBandedLayout, KronTravLayout}
264314
265315
266316krontravlayout (:: Vararg{Any,M} ) where M = KronTravLayout {M} ()
@@ -320,4 +370,27 @@ BroadcastStyle(::Type{KronTrav{T,N,AA,AXIS}}) where {T,N,AA,AXIS} =
320370# ##
321371
322372* (a:: Number , b:: KronTrav ) = KronTrav (a* first (b. args), tail (b. args)... )
323- * (a:: KronTrav , b:: Number ) = KronTrav (first (a. args)* b, tail (a. args)... )
373+ * (a:: KronTrav , b:: Number ) = KronTrav (first (a. args)* b, tail (a. args)... )
374+
375+
376+ function copy (M:: Mul{<:KronTravLayouts, <:DiagTravLayout} )
377+ K,x = M. A,M. B
378+ A,B = K. args
379+ _krontrav_mul_diagtrav (K. args, invdiagtrav (x), eltype (M))
380+ end
381+
382+ _krontrav_mul_diagtrav ((A,B), X:: AbstractMatrix , :: Type{T} ) where T = DiagTrav (convert (AbstractMatrix{T}, B* X* A' ))
383+ function _krontrav_mul_diagtrav ((A,B,C), X:: AbstractArray{<:Any,3} , :: Type{T} ) where T
384+ m,n,p = size (X)
385+ @assert m == n == p
386+ Y = similar (X, T)
387+ Z = similar (X, T)
388+ for k = 1 : n, j= 1 : n mul! (view (Y,k,j,:),A,view (X,k,j,:)) end
389+ for k = 1 : n, j= 1 : n mul! (view (Z,k,:,j),B,view (Y,k,:,j)) end
390+ for k = 1 : n, j= 1 : n mul! (view (Y,:,k,j),C,view (Z,:,k,j)) end
391+ DiagTrav (Y)
392+ end
393+
394+
395+
396+ # C = α*B*X*A' + β*C
0 commit comments