Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,33 @@ function Base.isapprox(B::BlockDiagonal, M::AbstractMatrix; kwargs...)::Bool
end

function Base.isapprox(B1::BlockDiagonal, B2::BlockDiagonal; kwargs...)
return isequal_blocksizes(B1, B2) && all(isapprox.(blocks(B1), blocks(B2); kwargs...))
return isequal_blocksizes(B1, B2) && all(isapprox(b1,b2; kwargs...) for (b1,b2) in zip(blocks(B1), blocks(B2)))
end

function isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)::Bool
return size(B1) == size(B2) && blocksizes(B1) == blocksizes(B2)
function isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)::Bool
length(blocks(B1)) == length(blocks(B2)) || return false
for (b1, b2) in zip(blocks(B1), blocks(B2))
size(b1) == size(b2) || return false
end
return true
end

function can_block_multiply(B1::BlockDiagonal, B2::BlockDiagonal)::Bool
length(blocks(B1)) == length(blocks(B2)) || return false
for (b1, b2) in zip(blocks(B1), blocks(B2))
size(b1, 2) == size(b2, 1) || return false
end
return true
end

function can_block_multiply(C::BlockDiagonal, A::BlockDiagonal,B::BlockDiagonal)::Bool
length(blocks(C)) == length(blocks(A)) == length(blocks(B)) || return false
for (c, a, b) in zip(blocks(C), blocks(A), blocks(B))
size(c, 1) == size(a, 1) &&
size(c, 2) == size(b, 2) &&
size(a, 2) == size(b, 1) || return false
end
return true
end

## Addition
Expand Down Expand Up @@ -81,7 +103,7 @@ Base.:*(B::BlockDiagonal, n::Number) = BlockDiagonal(n .* blocks(B))

# TODO make type stable, maybe via Broadcasting?
function Base.:*(B1::BlockDiagonal, B2::BlockDiagonal)
if isequal_blocksizes(B1, B2)
if can_block_multiply(B1, B2)
return BlockDiagonal(blocks(B1) .* blocks(B2))
else
return Matrix(B1) * Matrix(B2)
Expand Down
8 changes: 3 additions & 5 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,16 @@ if VERSION ≥ v"1.3"
end

function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal)
isequal_blocksizes(A, B) || throw(DimensionMismatch("A and B have different block sizes"))
isequal_blocksizes(C, A) || throw(DimensionMismatch("C has incompatible block sizes"))
can_block_multiply(C,A,B) || throw(DimensionMismatch("A has blocksizes $(blocksizes(A)), B has blocksizes $(blocksizes(B)), C has blocksizes $(blocksizes(C))"))
for i in eachindex(blocks(C))
@inbounds LinearAlgebra.mul!(C.blocks[i], A.blocks[i], B.blocks[i])
end

return C
end

#
function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number, β::Number)
isequal_blocksizes(A, B) || throw(DimensionMismatch("A and B have different block sizes"))
isequal_blocksizes(C, A) || throw(DimensionMismatch("C has incompatible block sizes"))
can_block_multiply(C,A,B) || throw(DimensionMismatch("A has blocksizes $(blocksizes(A)), B has blocksizes $(blocksizes(B)), C has blocksizes $(blocksizes(C))"))
for i in eachindex(blocks(C))
@inbounds LinearAlgebra.mul!(C.blocks[i], A.blocks[i], B.blocks[i], α, β)
end
Expand Down
9 changes: 9 additions & 0 deletions test/base_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ using Test
# Dimension check
@test sum(size.(b4.blocks, 1)) == size(b4 * b5, 1)
@test sum(size.(b5.blocks, 2)) == size(b4 * b5, 2)

b6 = BlockDiagonal([ones(4, 1), 2 * ones(2, 2)])
b46 = b4 * b6
@test b46 isa BlockDiagonal
@test b46 == [4 * ones(2,1) zeros(2,2); zeros(3,1) 8* ones(3,2)]
@test sum(size.(b4.blocks, 1)) == size(b46, 1)
@test sum(size.(b6.blocks, 2)) == size(b46, 2)

@test_throws DimensionMismatch b6 * b4
end
end # Multiplication
end
11 changes: 10 additions & 1 deletion test/blockdiagonal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using BlockDiagonals
using BlockDiagonals: isequal_blocksizes
using BlockDiagonals: isequal_blocksizes, can_block_multiply
using Random
using Test

Expand Down Expand Up @@ -77,6 +77,15 @@ using Test
@test isequal_blocksizes(b1, b2) == false
end

@testset "can_block_multiply" begin
@test can_block_multiply(b1, b1) == true
@test can_block_multiply(b1, b2) == false

@test can_block_multiply(b1, b1, b1) == true
@test can_block_multiply(b1, b1, b2) == false
@test can_block_multiply(b2, b1, b1) == false
end

@testset "blocks size" begin
B = BlockDiagonal([rand(3, 3), rand(4, 4)])
@test nblocks(B) == 2
Expand Down