Skip to content

Commit ac2fbec

Browse files
authored
Add blockkron (#123)
* Start work on blockkron * Fix axes, size * Add tests for matrix case * Add tests, clean up getindex
1 parent 470d388 commit ac2fbec

File tree

3 files changed

+128
-3
lines changed

3 files changed

+128
-3
lines changed

src/BlockArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export PseudoBlockArray, PseudoBlockMatrix, PseudoBlockVector, PseudoBlockVecOrM
1414

1515
export undef_blocks, undef, findblock, findblockindex
1616

17-
export khatri_rao
17+
export khatri_rao, blockkron, BlockKron
1818

1919
export blockappend!, blockpush!, blockpushfirst!, blockpop!, blockpopfirst!
2020

src/blockproduct.jl

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ References
66
* Khatri, C. G., and Rao, C. Radhakrishna (1968) Solutions to Some Functional Equations and Their Applications to Characterization of Probability Distributions. Sankhya: Indian J. Statistics, Series A 30, 167–180.
77
"""
88
function khatri_rao(A::AbstractBlockMatrix, B::AbstractBlockMatrix)
9-
#
9+
#
1010
Ablksize = blocksize(A)
1111
Bblksize = blocksize(B)
1212

@@ -27,4 +27,90 @@ end
2727

2828
function khatri_rao(A::AbstractMatrix, B::AbstractMatrix)
2929
kron(A, B)
30-
end
30+
end
31+
32+
""""
33+
BlockKron(A...)
34+
35+
creates a lazy representation of kron(A...) with the natural
36+
block-structure imposed. This is a component in `blockkron(A...)`.
37+
"""
38+
struct BlockKron{T,N,ARGS<:Tuple} <: AbstractBlockArray{T,N}
39+
args::ARGS
40+
end
41+
42+
BlockKron{T,N}(A...) where {T,N} = BlockKron{T,N,typeof(A)}(A)
43+
BlockKron{T}(A::AbstractVector, B::AbstractVector, C::AbstractVector...) where {T} = BlockKron{T,1}(A, B, C...)
44+
BlockKron{T}(A, B, C...) where {T} = BlockKron{T,2}(A, B, C...)
45+
BlockKron(A, B, C...) = BlockKron{mapreduce(eltype,promote_type,(A,B,C...))}(A, B, C...)
46+
47+
48+
size(B::BlockKron) = size(Kron(B))
49+
50+
size(K::BlockKron, j::Int) = prod(size.(K.args, j))
51+
size(a::BlockKron{<:Any,1}) = (size(a,1),)
52+
size(a::BlockKron{<:Any,2}) = (size(a,1), size(a,2))
53+
54+
function axes(K::BlockKron{<:Any,1})
55+
A,B = K.args
56+
(blockedrange(fill(prod(size.(tail(K.args),1)), size(K.args[1],1))),)
57+
end
58+
59+
function axes(K::BlockKron{<:Any,2})
60+
A,B = K.args
61+
blockedrange.((fill(prod(size.(tail(K.args),1)), size(K.args[1],1)),
62+
fill(prod(size.(tail(K.args),2)), size(K.args[1],2))))
63+
end
64+
65+
kron_getindex((A,)::Tuple{AbstractVector}, k::Integer) = A[k]
66+
function kron_getindex((A,B)::NTuple{2,AbstractVector}, k::Integer)
67+
K,κ = divrem(k-1, length(B))
68+
A[K+1]*B[κ+1]
69+
end
70+
kron_getindex((A,)::Tuple{AbstractMatrix}, k::Integer, j::Integer) = A[k,j]
71+
function kron_getindex((A,B)::NTuple{2,AbstractVecOrMat}, k::Integer, j::Integer)
72+
K,κ = divrem(k-1, size(B,1))
73+
J,ξ = divrem(j-1, size(B,2))
74+
A[K+1,J+1]*B[κ+1+1]
75+
end
76+
77+
kron_getindex(args::Tuple, k::Integer, j::Integer) = kron_getindex(tuple(BlockKron(args[1:2]...), args[3:end]...), k, j)
78+
kron_getindex(args::Tuple, k::Integer) = kron_getindex(tuple(BlockKron(args[1:2]...), args[3:end]...), k)
79+
80+
getindex(K::BlockKron{<:Any,1}, k::Integer) = kron_getindex(K.args, k)
81+
getindex(K::BlockKron{<:Any,2}, k::Integer, j::Integer) = kron_getindex(K.args, k, j)
82+
83+
kron_getblock((a,b)::Tuple{Any,Any}, k::Integer) = a[k]*b
84+
kron_getblock(args, k::Integer) = args[1][k]*BlockKron(tail(args)...)
85+
86+
kron_getblock((a,b)::Tuple{Any,Any}, k::Integer, j::Integer) = a[k,j]*b
87+
kron_getblock(args, k::Integer, j::Integer) = args[1][k,j]*BlockKron(tail(args)...)
88+
89+
getblock(K::BlockKron{<:Any,1}, k::Integer) = kron_getblock(K.args, k)
90+
getblock(K::BlockKron{<:Any,2}, k::Integer, j::Integer) = kron_getblock(K.args, k, j)
91+
92+
# const SubKron{T,M1,M2,R1,R2} = SubArray{T,2,<:BlockKron{T,M1,M2},<:Tuple{<:BlockSlice{R1},<:BlockSlice{R2}}}
93+
94+
95+
# BroadcastStyle(::Type{<:SubKron{<:Any,<:Any,B,Block1,Block1}}) where B =
96+
# BroadcastStyle(B)
97+
98+
99+
# allow dispatch on memory layout
100+
_blockkron(_, A) = BlockArray(BlockKron(A...))
101+
102+
103+
""""
104+
blockkron(A...)
105+
106+
creates a blocked version of kron(A...) with the natural
107+
block-structure imposed.
108+
"""
109+
blockkron(A...) = _blockkron(map(MemoryLayout,A), A)
110+
111+
"""
112+
blockvec(A::AbstractMatrix)
113+
114+
creates a blocked version of `vec(A)`, with the block structure used to represent the columns.
115+
"""
116+
blockvec(A::AbstractMatrix) = PseudoBlockVector(vec(A), Fill(size(A,1), size(A,2)))

test/test_blockproduct.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,42 @@ end
100100
@test khatri_rao(A, B) kron(A, B)
101101
end
102102

103+
@testset "blockkron" begin
104+
a = [1,2]
105+
b = [3,4,5]
106+
= BlockKron(a,b) # lazy version of blockkron
107+
k = blockkron(a,b)
108+
@test k ==== kron(a,b)
109+
@test blocksize(k) == blocksize(k̃) == size(a)
110+
111+
@test k[Block(1)] == k̃[Block(1)] == a[1]*b
112+
@test k[Block(2)] == k̃[Block(2)] == a[2]*b
113+
c = 6:8
114+
= BlockKron(a,b,c)
115+
@test== blockkron(a,b,c) == kron(a,b,c)
116+
@test k̄[Block(1)][Block(1)] == a[1]*b[1]*c
117+
@test k̄[Block(1)][Block(2)] == a[1]*b[2]*c
118+
@test k̄[Block(2)][Block(3)] == a[2]*b[3]*c
119+
120+
A = randn(2,3)
121+
B = randn(3,4)
122+
K = blockkron(A,B)
123+
= BlockKron(A,B)
124+
@test K ==== kron(A,B)
125+
@test blocksize(K) == blocksize(K̃) == size(A)
126+
@test K[Block(1,1)] == K̃[Block(1),Block(1)] == A[1,1]*B
127+
@test K[Block(2,3)] == K̃[Block(2),Block(3)] == A[2,3]*B
128+
C = randn(2,5)
129+
= BlockKron(A,B,C)
130+
@test== blockkron(A,B,C) == kron(A,B,C)
131+
@test K̄[Block(1,1)][Block(1,1)] A[1,1]*B[1,1]*C
132+
@test K̄[Block(2,3)][Block(3,4)] A[2,3]*B[3,4]*C
133+
134+
@test blockkron(a,B) == kron(a,B)
135+
@test blockkron(A,b) == kron(A,b)
136+
@test blockkron(A,b,c) == kron(A,b,c)
137+
@test blockkron(A,b,C) == kron(A,b,C)
138+
139+
@test_throws MethodError BlockKron()
140+
@test_throws MethodError BlockKron(a)
141+
end

0 commit comments

Comments
 (0)