Skip to content

Commit b71450c

Browse files
authored
Work on 3-argument KronTrav (#134)
* Work on 3-argument KronTrav * Update blockkron.jl * 3-vector KronTrav
1 parent bf02ac0 commit b71450c

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

src/blockkron.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ function _diagtravgetindex(::AbstractPaddedLayout{<:AbstractStridedLayout}, A::A
9595
end
9696
end
9797

98+
9899
function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Block{1}) where T
99100
k = Int(K)
100101
m,n,p = size(A)
@@ -103,7 +104,7 @@ function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Bl
103104
st3 = stride(A,3)
104105
ret = T[]
105106
for j = 0:k-1
106-
append!(ret, view(A, range(j*st + k-j; step=st3-st, length=j+1)))
107+
append!(ret, view(A, range(j*st + k-j; step=st3-st, length=j+1))) # this matches lexigraphical order
107108
end
108109
ret
109110
end
@@ -152,7 +153,7 @@ end
152153
size(A::InvDiagTrav) = (blocksize(A.vector,1),blocksize(A.vector,1))
153154

154155
function getindex(A::InvDiagTrav{T}, k::Int, j::Int) where T
155-
if k+j-1  blocksize(A.vector,1)
156+
if k+j-1 blocksize(A.vector,1)
156157
A.vector[Block(k+j-1)][j]
157158
else
158159
zero(T)
@@ -185,20 +186,35 @@ KronTrav(A::AbstractArray...) = KronTrav{mapreduce(eltype, promote_type, A)}(A..
185186
copy(K::KronTrav) = KronTrav(map(copy,K.args), K.axes)
186187
axes(A::KronTrav) = A.axes
187188

188-
function getindex(M::KronTrav{<:Any,1}, K::Block{1})
189-
A,B = M.args
189+
190+
191+
function _krontrav_getindex(K::Block{1}, A, B)
190192
m,n = length(A), length(B)
191193
mn = min(m,n)
192194
k = Int(K)
193-
if k  mn
195+
if k mn
194196
A[1:k] .* B[k:-1:1]
195-
elseif m < n
197+
elseif m < n
196198
A .* B[k:-1:(k-m+1)]
197199
else # n < m
198200
A[(k-n+1):k] .* B[end:-1:1]
199201
end
200202
end
201203

204+
205+
206+
function _krontrav_getindex(K::Block{1}, A, B, C)
207+
@assert length(A) == length(B) == length(C) # TODO: generalise
208+
209+
# make a tuple corresponding to lexigraphical order
210+
ret = Vector{promote_type(eltype(A),eltype(B),eltype(C))}()
211+
n = Int(K)
212+
for k = 1:n, j=1:k
213+
push!(ret, C[n-k+1]B[k-j+1]A[j])
214+
end
215+
ret
216+
end
217+
202218
function _krontrav_getindex(K::Block{2}, A, B)
203219
m,n = size(A), size(B)
204220
@assert m == n
@@ -219,6 +235,7 @@ function _krontrav_getindex(Kin::Block{2}, A, B, C)
219235
AB
220236
end
221237

238+
getindex(M::KronTrav{<:Any,1}, K::Block{1}) = _krontrav_getindex(K, M.args...)
222239
getindex(M::KronTrav{<:Any,2}, K::Block{2}) = _krontrav_getindex(K, M.args...)
223240

224241

test/test_blockkron.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,21 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
9999
a = [1,2,3]
100100
b = [4,5,6]
101101
c = [7,8]
102-
@test KronTrav(a,b) == DiagTrav(b*a')
102+
@test KronTrav(a,b) == DiagTrav(b*a') == DiagTrav(kron(a',b))
103103
@test KronTrav(a,c) == [7,8,14,16,21]
104104
@test KronTrav(c,a) == [7,14,8,21,16]
105+
106+
X = rotl90(Matrix(UpperTriangular(randn(3,3)))) # triangle of coefficients
107+
@test KronTrav(a,b)' * DiagTrav(X) b'*X*a sum(b .* X .* a')
108+
end
109+
110+
@testset "3-vectors" begin
111+
a = [1,2,3]
112+
b = [4,5,6]
113+
c = [7,8,9]
114+
115+
X = [k + j + l - 2 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3]
116+
@test KronTrav(a,b,c)' * DiagTrav(X) sum(c .* X .* b' .* reshape(a,1,1,3))
105117
end
106118

107119
@testset "matrix" begin

0 commit comments

Comments
 (0)