diff --git a/ext/ChainRulesCoreExt.jl b/ext/ChainRulesCoreExt.jl index 5c06344..dacc83d 100644 --- a/ext/ChainRulesCoreExt.jl +++ b/ext/ChainRulesCoreExt.jl @@ -58,7 +58,7 @@ function ChainRulesCore.rrule( ::typeof(*), bm::BlockDiagonal{T,V}, v::StridedVector{T}, -) where {T<:Union{Real,Complex},V<:Matrix{T}} +) where {T<:Union{Real,Complex},V<:AbstractVector{<:Matrix{T}}} y = bm * v diff --git a/src/blockdiagonal.jl b/src/blockdiagonal.jl index 27bce62..8191a21 100644 --- a/src/blockdiagonal.jl +++ b/src/blockdiagonal.jl @@ -1,20 +1,20 @@ # Core functionality for the `BlockDiagonal` type """ - BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T} + BlockDiagonal{T, BV<:AbstractVector{<:AbstractMatrix{T}}} <: AbstractMatrix{T} A matrix with matrices on the diagonal, and zeros off the diagonal. """ -struct BlockDiagonal{T,V<:AbstractMatrix{T}} <: AbstractMatrix{T} - blocks::Vector{V} +struct BlockDiagonal{T,BV<:AbstractVector{<:AbstractMatrix{T}}} <: AbstractMatrix{T} + blocks::BV - function BlockDiagonal{T,V}(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}} - return new{T,V}(blocks) + function BlockDiagonal{T,BV}(blocks::AbstractVector{V}) where {T,V<:AbstractMatrix{T},BV<:AbstractVector{V}} + return new{T,typeof(blocks)}(blocks) end end -function BlockDiagonal(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}} - return BlockDiagonal{T,V}(blocks) +function BlockDiagonal(blocks::AbstractVector{V}) where {T,V<:AbstractMatrix{T}} + return BlockDiagonal{T,typeof(blocks)}(blocks) end BlockDiagonal(B::BlockDiagonal) = B @@ -22,7 +22,7 @@ BlockDiagonal(B::BlockDiagonal) = B is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2) """ - blocks(B::BlockDiagonal{T, V}) -> Vector{V} + blocks(B::BlockDiagonal{T, Vector{Matrix{T}}}) -> Vector{Matrix{T}} Return the on-diagonal blocks of B. """ @@ -86,7 +86,7 @@ function getblock(B::BlockDiagonal{T}, p::Integer, q::Integer) where {T} return p == q ? blocks(B)[p] : Zeros{T}(blocksize(B, p, q)) end -function setblock!(B::BlockDiagonal{T,V}, v::V, p::Integer) where {T,V} +function setblock!(B::BlockDiagonal{T,V}, v::W, p::Integer) where {T,W,V<:AbstractVector{W}} if blocksize(B, p) != size(v) throw( DimensionMismatch( @@ -97,7 +97,7 @@ function setblock!(B::BlockDiagonal{T,V}, v::V, p::Integer) where {T,V} return blocks(B)[p] = v end -function setblock!(B::BlockDiagonal{T,V}, v::V, p::Int, q::Int) where {T,V} +function setblock!(B::BlockDiagonal{T,V}, v::W, p::Int, q::Int) where {T,W,V<:AbstractVector{W}} p == q || throw(ArgumentError("Cannot set off-diagonal block ($p, $q) to non-zero value.")) return setblock!(B, v, p) @@ -155,7 +155,7 @@ end end function Base.convert(::Type{BlockDiagonal{T,M}}, b::BlockDiagonal) where {T,M} - new_blocks = convert.(M, blocks(b)) + new_blocks = convert(M, blocks(b)) return BlockDiagonal(new_blocks)::BlockDiagonal{T,M} end diff --git a/test/blockdiagonal.jl b/test/blockdiagonal.jl index be39df5..805038c 100644 --- a/test/blockdiagonal.jl +++ b/test/blockdiagonal.jl @@ -175,7 +175,7 @@ using Test b = BlockDiagonal([special]) convert_first = BlockDiagonal([convert(Matrix, special)]) - convert_last = convert(BlockDiagonal{Float64,Matrix{Float64}}, b) + convert_last = convert(BlockDiagonal{Float64,Vector{Matrix{Float64}}}, b) @test convert_first == convert_last end diff --git a/test/linalg.jl b/test/linalg.jl index 9d35de0..66cab13 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -160,7 +160,7 @@ end @test C.UL ≈ C.U @test C.uplo === 'U' @test C.info == 0 - @test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Matrix{Float64}}} + @test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Vector{Matrix{Float64}}}} @test PDMat(cholesky(BD)) == PDMat(cholesky(Matrix(BD))) M = BlockDiagonal(map(Matrix, blocks(C.L))) @@ -171,7 +171,7 @@ end @test C.UL ≈ C.L @test C.uplo === 'L' @test C.info == 0 - @test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Matrix{Float64}}} + @test typeof(C) == Cholesky{Float64,BlockDiagonal{Float64,Vector{Matrix{Float64}}}} # we didn't think we needed to support this, but #109 d = Diagonal(rand(5))