Skip to content

Commit 763f359

Browse files
emeralidlfivefifty
authored andcommitted
Optimizing determinants and traces of Kron (#78)
* Optimize determinant and trace calculation for Kron * Add tests for det and tr of Kron * Optimize determinant and trace calculation for Kron * Add tests for det and tr of Kron * minor * Kron Tests: compare determinant calculation against Julia's implementation for kronecker of rectangular factors
1 parent db8b207 commit 763f359

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

src/LazyArrays.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module LazyArrays
66
using Base, Base.Broadcast, LinearAlgebra, FillArrays, StaticArrays, ArrayLayouts
77
import LinearAlgebra.BLAS
88

9-
import Base: AbstractArray, AbstractMatrix, AbstractVector,
9+
import Base: AbstractArray, AbstractMatrix, AbstractVector,
1010
ReinterpretArray, ReshapedArray, AbstractCartesianIndex, Slice,
1111
RangeIndex, BroadcastStyle, copyto!, length, broadcastable, axes,
1212
getindex, eltype, tail, IndexStyle, IndexLinear, getproperty,
@@ -36,8 +36,8 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3636
combine_eltypes, DefaultArrayStyle, instantiate, materialize,
3737
materialize!, eltypes
3838

39-
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot, factorize, qr, lu, cholesky,
40-
norm2, norm1, normInf, normMinusInf
39+
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot, factorize, qr, lu, cholesky,
40+
norm2, norm1, normInf, normMinusInf, det, tr
4141

4242
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4343

@@ -55,8 +55,8 @@ if VERSION < v"1.2-"
5555
import Base: has_offset_axes
5656
require_one_based_indexing(A...) = !has_offset_axes(A...) || throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
5757
else
58-
import Base: require_one_based_indexing
59-
end
58+
import Base: require_one_based_indexing
59+
end
6060

6161
export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
6262
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,

src/lazyoperations.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,42 @@ axes(a::Kron{<:Any,1}) = (OneTo(size(a,1)),)
2323
axes(a::Kron{<:Any,2}) = (OneTo(size(a,1)), OneTo(size(a,2)))
2424
axes(a::Kron{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> OneTo(size(a, M)), Val(N)))
2525

26+
27+
function det(K::Kron{<:Any, 2})
28+
(size(K, 1) == size(K, 2)) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(K))"))
29+
30+
d = 1.
31+
s = size(K, 1)
32+
33+
for A in K.args
34+
if size(A, 1) == size(A, 2)
35+
dA = det(A)
36+
if iszero(dA)
37+
return dA
38+
end
39+
d *= dA^(s ÷ size(A, 1))
40+
else
41+
# The Kronecker Product of rectangular matrices, if it is square, will
42+
# have determinant zero. This can be shown by using the fact that
43+
# rank(A ⊗ B) = rank(A)rank(B) and showing that this is strictly less
44+
# than the number of rows in the resulting Kronecker matrix. Hence,
45+
# since A ⊗ B does not have full rank, its determinant must be zero.
46+
return zero(d)
47+
end
48+
end
49+
return d
50+
end
51+
52+
function tr(K::Kron{<:Any, 2})
53+
(size(K, 1) == size(K, 2)) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(K))"))
54+
if all(A -> (size(A, 1) == size(A, 2)), K.args) # check if all component matrices are square
55+
return prod(tr.(K.args))
56+
else
57+
return sum(diag(K))
58+
end
59+
end
60+
61+
2662
getindex(K::Kron{T,1,<:Tuple{<:AbstractVector}}, k::Int) where T =
2763
first(K.args)[k]
2864

test/runtests.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,35 @@ include("broadcasttests.jl")
5858

5959
A = randn(3,2)
6060
B = randn(4,6)
61-
K = Kron(A,B)
62-
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A,B)) == kron(A,B)
63-
K = Kron(A,B')
64-
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A,B')) == kron(A,B')
65-
K = Kron(A',B)
66-
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A',B)) == kron(A',B)
67-
K = Kron(A',B')
68-
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A',B')) == kron(A',B')
61+
K, k = Kron(A,B), kron(A,B)
62+
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A,B)) == k
63+
@test det(K) == 0 # kronecker of rectangular factors
64+
@test isapprox(det(k), det(K); atol=eps(eltype(K)), rtol=0)
65+
@test tr(K) tr(k)
66+
67+
K, k = Kron(A,B'), kron(A,B')
68+
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A,B')) == k
69+
@test_throws DimensionMismatch det(K)
70+
@test_throws DimensionMismatch tr(K)
71+
72+
K, k = Kron(A',B), kron(A',B)
73+
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A',B)) == k
74+
@test_throws DimensionMismatch det(K)
75+
@test_throws DimensionMismatch tr(K)
76+
77+
K, k = Kron(A',B'), kron(A',B')
78+
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A',B')) == k
79+
@test det(K) == 0 # kronecker of rectangular factors
80+
@test isapprox(det(k), det(K); atol=eps(eltype(K)), rtol=0)
81+
@test tr(K) tr(k)
82+
83+
A = randn(3,3)
84+
B = randn(6,6)
85+
C = randn(2,2)
86+
K, k = Kron(A,B,C), kron(A,B,C)
87+
@test [K[k,j] for k=1:size(K,1), j=1:size(K,2)] == Array(Kron(A,B,C)) == k
88+
@test det(K) det(k)
89+
@test tr(K) tr(k)
6990

7091
A = randn(3,2)
7192
B = randn(4,6)
@@ -157,7 +178,7 @@ end
157178
end
158179

159180
@testset "colsupport past size" begin
160-
C = cache(Zeros(5,5)); C[5,1];
181+
C = cache(Zeros(5,5)); C[5,1];
161182
@test colsupport(C,1) == Base.OneTo(5)
162183
@test colsupport(C,3) == 1:0
163184
@test rowsupport(C,1) == Base.OneTo(1)
@@ -258,12 +279,12 @@ end
258279
v == BroadcastVector(exp, [1,2,3]) == exp.([1,2,3])
259280

260281
Base.IndexStyle(typeof(BroadcastVector(exp, [1,2,3]))) == IndexLinear()
261-
282+
262283
bc = broadcasted(exp,[1 2; 3 4])
263284
M = BroadcastArray(exp, [1 2; 3 4])
264285
@test BroadcastArray(bc) == BroadcastMatrix(bc) == BroadcastMatrix{Float64,typeof(exp),typeof(bc.args)}(bc) ==
265286
M == BroadcastMatrix(BroadcastMatrix(bc)) == BroadcastMatrix(exp,[1 2; 3 4]) == exp.([1 2; 3 4])
266-
287+
267288
@test exp.(v') isa BroadcastMatrix
268289
@test exp.(transpose(v)) isa BroadcastMatrix
269290
@test exp.(M') isa BroadcastMatrix

0 commit comments

Comments
 (0)