Skip to content

Commit 4f04da5

Browse files
Allows blocks to be <:AbstractVector or <:Tuple
- This is a breaking change, as the type param `V` is now different, being the type of the `blocks` collection, rather than the `eltype` of that collection. - This allows allows the collection to be something other than a `Vector` e.g. - a `CuArray` of matrices (allowing use on GPU hopefully) - a `Tuple` with different concrete subtypes of `AbstractMatrix{T}` (avoiding heterogeneous block to being abstractly typed i.e. a way to avoid `Vector{AbstractMatrix{T}}`)
1 parent fe99ce7 commit 4f04da5

File tree

6 files changed

+68
-32
lines changed

6 files changed

+68
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.11"
4+
version = "0.2.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/blockdiagonal.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
# Core functionality for the `BlockDiagonal` type
22

33
"""
4-
BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
4+
BlockDiagonal{T, V} <: AbstractMatrix{T}
5+
BlockDiagonal(blocks::V) -> BlockDiagonal{T,V}
56
67
A matrix with matrices on the diagonal, and zeros off the diagonal.
7-
"""
8-
struct BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
9-
blocks::Vector{V}
108
11-
function BlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
12-
return new{T, V}(blocks)
13-
end
9+
!!! info "`V` type"
10+
`blocks::V` should be a `Tuple` or `AbstractVector` where each component (each block) is
11+
`<:AbstractMatrix{T}` for some common element type `T`.
12+
"""
13+
struct BlockDiagonal{T, V} <: AbstractMatrix{T}
14+
blocks::V
1415
end
1516

16-
function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
17+
function BlockDiagonal(blocks::V) where {
18+
T, V<:Union{Tuple{Vararg{<:AbstractMatrix{T}}}, AbstractVector{<:AbstractMatrix{T}}}
19+
}
1720
return BlockDiagonal{T, V}(blocks)
1821
end
1922

20-
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
23+
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::V) where {V}
2124
BlockDiagonal_pullback::Composite) = (NO_FIELDS, Δ.blocks)
2225
return BlockDiagonal(blocks), BlockDiagonal_pullback
2326
end
@@ -132,7 +135,7 @@ function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagon
132135
Δblocks = map(eachindex(nrows)) do n
133136
block_rows = row_idxs[n]:(row_idxs[n] + nrows[n] - 1)
134137
block_cols = col_idxs[n]:(col_idxs[n] + ncols[n] - 1)
135-
return Δ[block_rows, block_cols]
138+
return Δ[block_rows, block_cols]
136139
end
137140
return (NO_FIELDS, Composite{T}(blocks=Δblocks))
138141
end
@@ -173,7 +176,9 @@ function _block_indices(B::BlockDiagonal, i::Integer, j::Integer)
173176
p += 1
174177
j -= ncols[p]
175178
end
176-
i -= sum(nrows[1:(p-1)])
179+
if !isempty(nrows[1:(p-1)])
180+
i -= sum(nrows[1:(p-1)])
181+
end
177182
# if row `i` outside of block `p`, set `p` to place-holder value `-1`
178183
if i <= 0 || i > nrows[p]
179184
p = -1

src/linalg.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ svdvals_blockwise(B::BlockDiagonal) = mapreduce(svdvals, vcat, blocks(B))
4242
LinearAlgebra.svdvals(B::BlockDiagonal) = sort!(svdvals_blockwise(B); rev=true)
4343

4444
# `B = U * Diagonal(S) * Vt` with `U` and `Vt` `BlockDiagonal` (`S` only sorted block-wise).
45-
function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T
45+
function svd_blockwise(B::BlockDiagonal{T, <:AbstractVector}; full::Bool=false) where T
4646
U = Matrix{float(T)}[]
4747
S = Vector{float(T)}()
4848
Vt = Matrix{float(T)}[]
@@ -54,6 +54,17 @@ function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T
5454
end
5555
return BlockDiagonal(U), S, BlockDiagonal(Vt)
5656
end
57+
function svd_blockwise(B::BlockDiagonal{T, <:Tuple}; full::Bool=false) where T
58+
S = Vector{float(T)}()
59+
U_Vt = ntuple(length(blocks(B))) do i
60+
F = svd(getblock(B, i), full=full)
61+
append!(S, F.S)
62+
(F.U, F.Vt)
63+
end
64+
U = first.(U_Vt)
65+
Vt = last.(U_Vt)
66+
return BlockDiagonal(U), S, BlockDiagonal(Vt)
67+
end
5768

5869
function LinearAlgebra.svd(B::BlockDiagonal; full::Bool=false)::SVD
5970
U, S, Vt = svd_blockwise(B, full=full)

test/base_maths.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ using Test
77
rng = MersenneTwister(123456)
88
N1, N2, N3 = 3, 4, 5
99
N = N1 + N2 + N3
10-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
11-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
12-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
10+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
11+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
12+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
13+
14+
@testset "$T" for (T, (b1, b2, b3)) in (
15+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
16+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
17+
)
1318
A = rand(rng, N, N + N1)
1419
B = rand(rng, N + N1, N + N2)
1520
A′, B′ = A', B'
@@ -127,8 +132,8 @@ using Test
127132
end
128133

129134
@testset "Non-Square BlockDiagonal * Non-Square BlockDiagonal" begin
130-
b4 = BlockDiagonal([ones(2, 4), 2 * ones(3, 2)])
131-
b5 = BlockDiagonal([3 * ones(2, 2), 2 * ones(4, 1)])
135+
b4 = BlockDiagonal(T([ones(2, 4), 2 * ones(3, 2)]))
136+
b5 = BlockDiagonal(T([3 * ones(2, 2), 2 * ones(4, 1)]))
132137

133138
@test b4 * b5 isa Array
134139
@test b4 * b5 == [6 * ones(2, 2) 4 * ones(2, 1); zeros(3, 2) 8 * ones(3, 1)]
@@ -138,3 +143,4 @@ using Test
138143
end
139144
end # Multiplication
140145
end
146+
end

test/blockdiagonal.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@ end
1313
rng = MersenneTwister(123456)
1414
N1, N2, N3 = 3, 4, 5
1515
N = N1 + N2 + N3
16-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
17-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
18-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
16+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
17+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
18+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
19+
20+
@testset "$T" for (T, (b1, b2, b3)) in (
21+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
22+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
23+
)
1924
A = rand(rng, N, N + N1)
2025
B = rand(rng, N + N1, N + N2)
2126
A′, B′ = A', B'
@@ -48,8 +53,10 @@ end
4853
end
4954

5055
@testset "parent" begin
51-
@test parent(b1) isa Vector{<:AbstractMatrix}
56+
@test parent(b1) isa Union{Tuple,AbstractVector}
57+
@test eltype(parent(b1)) <: AbstractMatrix
5258
@test parent(BlockDiagonal([X, Y])) == [X, Y]
59+
@test parent(BlockDiagonal((X, Y))) == (X, Y)
5360
end
5461

5562
@testset "similar" begin
@@ -138,3 +145,4 @@ end
138145
@test_throws DimensionMismatch copy!(b2, b1)
139146
end
140147
end
148+
end

test/linalg.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@ using Test
88
rng = MersenneTwister(123456)
99
N1, N2, N3 = 3, 4, 5
1010
N = N1 + N2 + N3
11-
b1 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)])
12-
b2 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)])
13-
b3 = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)])
11+
blocks1 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N3, N3)]
12+
blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)]
13+
blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)]
14+
15+
@testset "$T" for (T, (b1, b2, b3)) in (
16+
Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))),
17+
Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)),
18+
)
1419
A = rand(rng, N, N + N1)
1520
B = rand(rng, N + N1, N + N2)
1621
A′, B′ = A', B'
@@ -55,7 +60,7 @@ using Test
5560

5661
@testset "eigvals on LinearAlgebra types" begin
5762
# `eigvals` has different methods for different types, e.g. Hermitian
58-
b_herm = BlockDiagonal([Hermitian(rand(rng, 3, 3) + I) for _ in 1:3])
63+
b_herm = BlockDiagonal(T(Hermitian(rand(rng, 3, 3) + I) for _ in 1:3))
5964
@test eigvals(b_herm) eigvals(Matrix(b_herm))
6065
@test eigvals(b_herm, 1.0, 2.0) eigvals(Hermitian(Matrix(b_herm)), 1.0, 2.0)
6166
end
@@ -75,21 +80,21 @@ using Test
7580
0.0 1.0 5.0
7681
0.0 0.0 3.0]
7782

78-
B = BlockDiagonal([X, X])
83+
B = BlockDiagonal(T([X, X]))
7984
C = cholesky(B)
8085
@test C isa Cholesky{Float64, <:BlockDiagonal{Float64}}
8186
@test C.U cholesky(Matrix(B)).U
82-
@test C.U BlockDiagonal([U, U])
83-
@test C.L BlockDiagonal([U', U'])
87+
@test C.U BlockDiagonal(T([U, U]))
88+
@test C.L BlockDiagonal(T([U', U']))
8489
@test C.UL C.U
8590
@test C.uplo === 'U'
8691
@test C.info == 0
8792

8893
M = BlockDiagonal(map(Matrix, blocks(C.L)))
8994
C = Cholesky(M, 'L', 0)
9095
@test C.U cholesky(Matrix(B)).U
91-
@test C.U BlockDiagonal([U, U])
92-
@test C.L BlockDiagonal([U', U'])
96+
@test C.U BlockDiagonal(T([U, U]))
97+
@test C.L BlockDiagonal(T([U', U']))
9398
@test C.UL C.L
9499
@test C.uplo === 'L'
95100
@test C.info == 0
@@ -99,7 +104,7 @@ using Test
99104
X = [ 4 12 -16
100105
12 37 -43
101106
-16 -43 98]
102-
B = BlockDiagonal([X, X])
107+
B = BlockDiagonal(T([X, X]))
103108

104109
@testset "full=$full" for full in (true, false)
105110

@@ -142,3 +147,4 @@ using Test
142147
end
143148
end # SVD
144149
end
150+
end

0 commit comments

Comments
 (0)