Skip to content

Commit ac3ddc2

Browse files
Allows blocks to be <:AbstractVector or <:Tuple
- This is a breaking change, as the type param `V` is now different, being the type of the `blocks` collection, rather than the `eltype` of that collection. - This allows allows the collection to be something other than a `Vector` e.g. - a `CuArray` of matrices (allowing use on GPU hopefully) - a `Tuple` with different concrete subtypes of `AbstractMatrix{T}` (avoiding heterogeneous block to being abstractly typed i.e. a way to avoid `Vector{AbstractMatrix{T}}`)
1 parent fda7259 commit ac3ddc2

File tree

7 files changed

+68
-32
lines changed

7 files changed

+68
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.14"
4+
version = "0.2.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/blockdiagonal.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
# Core functionality for the `BlockDiagonal` type
22

33
"""
4-
BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
4+
BlockDiagonal{T, V} <: AbstractMatrix{T}
5+
BlockDiagonal(blocks::V) -> BlockDiagonal{T,V}
56
67
A matrix with matrices on the diagonal, and zeros off the diagonal.
7-
"""
8-
struct BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
9-
blocks::Vector{V}
108
11-
function BlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
12-
return new{T, V}(blocks)
13-
end
9+
!!! info "`V` type"
10+
`blocks::V` should be a `Tuple` or `AbstractVector` where each component (each block) is
11+
`<:AbstractMatrix{T}` for some common element type `T`.
12+
"""
13+
struct BlockDiagonal{T, V} <: AbstractMatrix{T}
14+
blocks::V
1415
end
1516

16-
function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
17+
function BlockDiagonal(blocks::V) where {
18+
T, V<:Union{Tuple{Vararg{<:AbstractMatrix{T}}}, AbstractVector{<:AbstractMatrix{T}}}
19+
}
1720
return BlockDiagonal{T, V}(blocks)
1821
end
1922

@@ -151,7 +154,9 @@ function _block_indices(B::BlockDiagonal, i::Integer, j::Integer)
151154
p += 1
152155
j -= ncols[p]
153156
end
154-
i -= sum(nrows[1:(p-1)])
157+
if !isempty(nrows[1:(p-1)])
158+
i -= sum(nrows[1:(p-1)])
159+
end
155160
# if row `i` outside of block `p`, set `p` to place-holder value `-1`
156161
if i <= 0 || i > nrows[p]
157162
p = -1

src/chainrules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# constructor
2-
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
2+
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::V) where {V<:AbstractVector}
33
BlockDiagonal_pullback::Composite) = (NO_FIELDS, Δ.blocks)
44
return BlockDiagonal(blocks), BlockDiagonal_pullback
55
end
@@ -27,7 +27,7 @@ function ChainRulesCore.rrule(
2727
::typeof(*),
2828
bm::BlockDiagonal{T, V},
2929
v::StridedVector{T}
30-
) where {T<:Union{Real, Complex}, V<:Matrix{T}}
30+
) where {T<:Union{Real, Complex}, V<:Vector{Matrix{T}}}
3131

3232
y = bm * v
3333

src/linalg.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ svdvals_blockwise(B::BlockDiagonal) = mapreduce(svdvals, vcat, blocks(B))
7676
LinearAlgebra.svdvals(B::BlockDiagonal) = sort!(svdvals_blockwise(B); rev=true)
7777

7878
# `B = U * Diagonal(S) * Vt` with `U` and `Vt` `BlockDiagonal` (`S` only sorted block-wise).
79-
function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T
79+
function svd_blockwise(B::BlockDiagonal{T, <:AbstractVector}; full::Bool=false) where T
8080
U = Matrix{float(T)}[]
8181
S = Vector{float(T)}()
8282
Vt = Matrix{float(T)}[]
@@ -88,6 +88,17 @@ function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T
8888
end
8989
return BlockDiagonal(U), S, BlockDiagonal(Vt)
9090
end
91+
function svd_blockwise(B::BlockDiagonal{T, <:Tuple}; full::Bool=false) where T
92+
S = Vector{float(T)}()
93+
U_Vt = ntuple(length(blocks(B))) do i
94+
F = svd(getblock(B, i), full=full)
95+
append!(S, F.S)
96+
(F.U, F.Vt)
97+
end
98+
U = first.(U_Vt)
99+
Vt = last.(U_Vt)
100+
return BlockDiagonal(U), S, BlockDiagonal(Vt)
101+
end
91102

92103
function LinearAlgebra.svd(B::BlockDiagonal; full::Bool=false)::SVD
93104
U, S, Vt = svd_blockwise(B, full=full)

test/base_maths.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ using Test
77
rng = MersenneTwister(123456)
88
N1, N2, N3 = 3, 4, 5
99
N = N1 + N2 + N3
10-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
11-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
12-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
10+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
11+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
12+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
13+
14+
@testset "$T" for (T, (b1, b2, b3)) in (
15+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
16+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
17+
)
1318
A = rand(rng, N, N + N1)
1419
B = rand(rng, N + N1, N + N2)
1520
A′, B′ = A', B'
@@ -127,8 +132,8 @@ using Test
127132
end
128133

129134
@testset "Non-Square BlockDiagonal * Non-Square BlockDiagonal" begin
130-
b4 = BlockDiagonal([ones(2, 4), 2 * ones(3, 2)])
131-
b5 = BlockDiagonal([3 * ones(2, 2), 2 * ones(4, 1)])
135+
b4 = BlockDiagonal(T([ones(2, 4), 2 * ones(3, 2)]))
136+
b5 = BlockDiagonal(T([3 * ones(2, 2), 2 * ones(4, 1)]))
132137

133138
@test b4 * b5 isa Array
134139
@test b4 * b5 == [6 * ones(2, 2) 4 * ones(2, 1); zeros(3, 2) 8 * ones(3, 1)]
@@ -138,3 +143,4 @@ using Test
138143
end
139144
end # Multiplication
140145
end
146+
end

test/blockdiagonal.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ using Test
77
rng = MersenneTwister(123456)
88
N1, N2, N3 = 3, 4, 5
99
N = N1 + N2 + N3
10-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
11-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
12-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
10+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
11+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
12+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
13+
14+
@testset "$T" for (T, (b1, b2, b3)) in (
15+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
16+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
17+
)
1318
A = rand(rng, N, N + N1)
1419
B = rand(rng, N + N1, N + N2)
1520
A′, B′ = A', B'
@@ -42,8 +47,10 @@ using Test
4247
end
4348

4449
@testset "parent" begin
45-
@test parent(b1) isa Vector{<:AbstractMatrix}
50+
@test parent(b1) isa Union{Tuple,AbstractVector}
51+
@test eltype(parent(b1)) <: AbstractMatrix
4652
@test parent(BlockDiagonal([X, Y])) == [X, Y]
53+
@test parent(BlockDiagonal((X, Y))) == (X, Y)
4754
end
4855

4956
@testset "similar" begin
@@ -117,3 +124,4 @@ using Test
117124
@test_throws DimensionMismatch copy!(b2, b1)
118125
end
119126
end
127+
end

test/linalg.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@ using Test
88
rng = MersenneTwister(123456)
99
N1, N2, N3 = 3, 4, 5
1010
N = N1 + N2 + N3
11-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
12-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
13-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
11+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
12+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
13+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
14+
15+
@testset "$T" for (T, (b1, b2, b3)) in (
16+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
17+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
18+
)
1419
A = rand(rng, N, N + N1)
1520
B = rand(rng, N + N1, N + N2)
1621
A′, B′ = A', B'
@@ -143,7 +148,7 @@ using Test
143148

144149
@testset "eigvals on LinearAlgebra types" begin
145150
# `eigvals` has different methods for different types, e.g. Hermitian
146-
b_herm = BlockDiagonal([Hermitian(rand(rng, 3, 3) + I) for _ in 1:3])
151+
b_herm = BlockDiagonal(T(Hermitian(rand(rng, 3, 3) + I) for _ in 1:3))
147152
@test eigvals(b_herm) eigvals(Matrix(b_herm))
148153
@test eigvals(b_herm, 1.0, 2.0) eigvals(Hermitian(Matrix(b_herm)), 1.0, 2.0)
149154
end
@@ -163,21 +168,21 @@ using Test
163168
0.0 1.0 5.0
164169
0.0 0.0 3.0]
165170

166-
B = BlockDiagonal([X, X])
171+
B = BlockDiagonal(T([X, X]))
167172
C = cholesky(B)
168173
@test C isa Cholesky{Float64, <:BlockDiagonal{Float64}}
169174
@test C.U cholesky(Matrix(B)).U
170-
@test C.U BlockDiagonal([U, U])
171-
@test C.L BlockDiagonal([U', U'])
175+
@test C.U BlockDiagonal(T([U, U]))
176+
@test C.L BlockDiagonal(T([U', U']))
172177
@test C.UL C.U
173178
@test C.uplo === 'U'
174179
@test C.info == 0
175180

176181
M = BlockDiagonal(map(Matrix, blocks(C.L)))
177182
C = Cholesky(M, 'L', 0)
178183
@test C.U cholesky(Matrix(B)).U
179-
@test C.U BlockDiagonal([U, U])
180-
@test C.L BlockDiagonal([U', U'])
184+
@test C.U BlockDiagonal(T([U, U]))
185+
@test C.L BlockDiagonal(T([U', U']))
181186
@test C.UL C.L
182187
@test C.uplo === 'L'
183188
@test C.info == 0
@@ -187,7 +192,7 @@ using Test
187192
X = [ 4 12 -16
188193
12 37 -43
189194
-16 -43 98]
190-
B = BlockDiagonal([X, X])
195+
B = BlockDiagonal(T([X, X]))
191196

192197
@testset "full=$full" for full in (true, false)
193198

@@ -230,3 +235,4 @@ using Test
230235
end
231236
end # SVD
232237
end
238+
end

0 commit comments

Comments
 (0)