Skip to content

Commit 0108219

Browse files
committed
Make BlockDiagonal block storage generic
Before this commit, `BlockDiagonal` uses a `Vector{V}` to store the blocks. The issue is that for trimmable binaries it is advantageous to use stronger typing, such as `SizedVector` from `StaticArrays`. To support such cases, we make the vector type generic, and make the required modifications in the rest of the package.
1 parent d1b2191 commit 0108219

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

ext/ChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function ChainRulesCore.rrule(
5858
::typeof(*),
5959
bm::BlockDiagonal{T,V},
6060
v::StridedVector{T},
61-
) where {T<:Union{Real,Complex},V<:Matrix{T}}
61+
) where {T<:Union{Real,Complex},V<:AbstractVector{<:Matrix{T}}}
6262

6363
y = bm * v
6464

src/blockdiagonal.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
66
A matrix with matrices on the diagonal, and zeros off the diagonal.
77
"""
8-
struct BlockDiagonal{T,V<:AbstractMatrix{T}} <: AbstractMatrix{T}
9-
blocks::Vector{V}
8+
struct BlockDiagonal{T,BV<:AbstractVector{<:AbstractMatrix{T}}} <: AbstractMatrix{T}
9+
blocks::BV
1010

11-
function BlockDiagonal{T,V}(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}}
12-
return new{T,V}(blocks)
11+
function BlockDiagonal{T,BV}(blocks::AbstractVector{V}) where {T,V<:AbstractMatrix{T},BV<:AbstractVector{V}}
12+
return new{T,typeof(blocks)}(blocks)
1313
end
1414
end
1515

16-
function BlockDiagonal(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}}
17-
return BlockDiagonal{T,V}(blocks)
16+
function BlockDiagonal(blocks::AbstractVector{V}) where {T,V<:AbstractMatrix{T}}
17+
return BlockDiagonal{T,typeof(blocks)}(blocks)
1818
end
1919

2020
BlockDiagonal(B::BlockDiagonal) = B
@@ -86,7 +86,7 @@ function getblock(B::BlockDiagonal{T}, p::Integer, q::Integer) where {T}
8686
return p == q ? blocks(B)[p] : Zeros{T}(blocksize(B, p, q))
8787
end
8888

89-
function setblock!(B::BlockDiagonal{T,V}, v::V, p::Integer) where {T,V}
89+
function setblock!(B::BlockDiagonal{T,V}, v::W, p::Integer) where {T,W,V<:AbstractVector{W}}
9090
if blocksize(B, p) != size(v)
9191
throw(
9292
DimensionMismatch(
@@ -97,7 +97,7 @@ function setblock!(B::BlockDiagonal{T,V}, v::V, p::Integer) where {T,V}
9797
return blocks(B)[p] = v
9898
end
9999

100-
function setblock!(B::BlockDiagonal{T,V}, v::V, p::Int, q::Int) where {T,V}
100+
function setblock!(B::BlockDiagonal{T,V}, v::W, p::Int, q::Int) where {T,W,V<:AbstractVector{W}}
101101
p == q ||
102102
throw(ArgumentError("Cannot set off-diagonal block ($p, $q) to non-zero value."))
103103
return setblock!(B, v, p)
@@ -155,7 +155,7 @@ end
155155
end
156156

157157
function Base.convert(::Type{BlockDiagonal{T,M}}, b::BlockDiagonal) where {T,M}
158-
new_blocks = convert.(M, blocks(b))
158+
new_blocks = convert(M, blocks(b))
159159
return BlockDiagonal(new_blocks)::BlockDiagonal{T,M}
160160
end
161161

test/blockdiagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ using Test
175175
b = BlockDiagonal([special])
176176

177177
convert_first = BlockDiagonal([convert(Matrix, special)])
178-
convert_last = convert(BlockDiagonal{Float64,Matrix{Float64}}, b)
178+
convert_last = convert(BlockDiagonal{Float64,Vector{Matrix{Float64}}}, b)
179179

180180
@test convert_first == convert_last
181181
end

test/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ end
160160
@test C.UL C.U
161161
@test C.uplo === 'U'
162162
@test C.info == 0
163-
@test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Matrix{Float64}}}
163+
@test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Vector{Matrix{Float64}}}}
164164
@test PDMat(cholesky(BD)) == PDMat(cholesky(Matrix(BD)))
165165

166166
M = BlockDiagonal(map(Matrix, blocks(C.L)))
@@ -171,7 +171,7 @@ end
171171
@test C.UL C.L
172172
@test C.uplo === 'L'
173173
@test C.info == 0
174-
@test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Matrix{Float64}}}
174+
@test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Vector{Matrix{Float64}}}}
175175

176176
# we didn't think we needed to support this, but #109
177177
d = Diagonal(rand(5))

0 commit comments

Comments
 (0)