From 2d9801c1eec5ccc911d1a9f36878f6dfc84c1c9e Mon Sep 17 00:00:00 2001 From: Viktor Svensson Date: Sat, 15 Apr 2023 17:37:33 +0200 Subject: [PATCH 1/2] allow multiplication of non-square blocks --- src/base_maths.jl | 30 ++++++++++++++++++++++++++---- src/linalg.jl | 8 +++----- test/base_maths.jl | 5 +++++ test/blockdiagonal.jl | 11 ++++++++++- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/src/base_maths.jl b/src/base_maths.jl index b4e585b..f61c69e 100644 --- a/src/base_maths.jl +++ b/src/base_maths.jl @@ -6,11 +6,33 @@ function Base.isapprox(B::BlockDiagonal, M::AbstractMatrix; kwargs...)::Bool end function Base.isapprox(B1::BlockDiagonal, B2::BlockDiagonal; kwargs...) - return isequal_blocksizes(B1, B2) && all(isapprox.(blocks(B1), blocks(B2); kwargs...)) + return isequal_blocksizes(B1, B2) && all(isapprox(b1,b2; kwargs...) for (b1,b2) in zip(blocks(B1), blocks(B2))) end -function isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)::Bool - return size(B1) == size(B2) && blocksizes(B1) == blocksizes(B2) +function isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)::Bool + length(blocks(B1)) == length(blocks(B2)) || return false + for (b1, b2) in zip(blocks(B1), blocks(B2)) + size(b1) == size(b2) || return false + end + return true +end + +function can_block_multiply(B1::BlockDiagonal, B2::BlockDiagonal)::Bool + length(blocks(B1)) == length(blocks(B2)) || return false + for (b1, b2) in zip(blocks(B1), blocks(B2)) + size(b1, 2) == size(b2, 1) || return false + end + return true +end + +function can_block_multiply(C::BlockDiagonal, A::BlockDiagonal,B::BlockDiagonal)::Bool + length(blocks(C)) == length(blocks(A)) == length(blocks(B)) || return false + for (c, a, b) in zip(blocks(C), blocks(A), blocks(B)) + size(c, 1) == size(a, 1) && + size(c, 2) == size(b, 2) && + size(a, 2) == size(b, 1) || return false + end + return true end ## Addition @@ -81,7 +103,7 @@ Base.:*(B::BlockDiagonal, n::Number) = BlockDiagonal(n .* blocks(B)) # TODO make type stable, maybe via Broadcasting? function Base.:*(B1::BlockDiagonal, B2::BlockDiagonal) - if isequal_blocksizes(B1, B2) + if can_block_multiply(B1, B2) return BlockDiagonal(blocks(B1) .* blocks(B2)) else return Matrix(B1) * Matrix(B2) diff --git a/src/linalg.jl b/src/linalg.jl index 01449c9..612ba61 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -138,18 +138,16 @@ if VERSION ≥ v"1.3" end function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal) - isequal_blocksizes(A, B) || throw(DimensionMismatch("A and B have different block sizes")) - isequal_blocksizes(C, A) || throw(DimensionMismatch("C has incompatible block sizes")) + can_block_multiply(C,A,B) || throw(DimensionMismatch("A has blocksizes $(blocksizes(A)), B has blocksizes $(blocksizes(B)), C has blocksizes $(blocksizes(C))")) for i in eachindex(blocks(C)) @inbounds LinearAlgebra.mul!(C.blocks[i], A.blocks[i], B.blocks[i]) end return C end - +# function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number, β::Number) - isequal_blocksizes(A, B) || throw(DimensionMismatch("A and B have different block sizes")) - isequal_blocksizes(C, A) || throw(DimensionMismatch("C has incompatible block sizes")) + can_block_multiply(C,A,B) || throw(DimensionMismatch("A has blocksizes $(blocksizes(A)), B has blocksizes $(blocksizes(B)), C has blocksizes $(blocksizes(C))")) for i in eachindex(blocks(C)) @inbounds LinearAlgebra.mul!(C.blocks[i], A.blocks[i], B.blocks[i], α, β) end diff --git a/test/base_maths.jl b/test/base_maths.jl index 854ea74..4f20ce5 100644 --- a/test/base_maths.jl +++ b/test/base_maths.jl @@ -183,6 +183,11 @@ using Test # Dimension check @test sum(size.(b4.blocks, 1)) == size(b4 * b5, 1) @test sum(size.(b5.blocks, 2)) == size(b4 * b5, 2) + + b6 = BlockDiagonal([ones(4, 2), 2 * ones(2, 4)]) + @test b4 * b6 isa BlockDiagonal + @test sum(size.(b4.blocks, 1)) == size(b4 * b6, 1) + @test sum(size.(b6.blocks, 2)) == size(b4 * b6, 2) end end # Multiplication end diff --git a/test/blockdiagonal.jl b/test/blockdiagonal.jl index 3b23632..2b51305 100644 --- a/test/blockdiagonal.jl +++ b/test/blockdiagonal.jl @@ -1,5 +1,5 @@ using BlockDiagonals -using BlockDiagonals: isequal_blocksizes +using BlockDiagonals: isequal_blocksizes, can_block_multiply using Random using Test @@ -77,6 +77,15 @@ using Test @test isequal_blocksizes(b1, b2) == false end + @testset "can_block_multiply" begin + @test can_block_multiply(b1, b1) == true + @test can_block_multiply(b1, b2) == false + + @test can_block_multiply(b1, b1, b1) == true + @test can_block_multiply(b1, b1, b2) == false + @test can_block_multiply(b2, b1, b1) == false + end + @testset "blocks size" begin B = BlockDiagonal([rand(3, 3), rand(4, 4)]) @test nblocks(B) == 2 From e139739d38f0da618c37a2b631241ac52686c697 Mon Sep 17 00:00:00 2001 From: Viktor Svensson Date: Sat, 15 Apr 2023 18:08:57 +0200 Subject: [PATCH 2/2] more tests --- test/base_maths.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/base_maths.jl b/test/base_maths.jl index 4f20ce5..8ace533 100644 --- a/test/base_maths.jl +++ b/test/base_maths.jl @@ -184,10 +184,14 @@ using Test @test sum(size.(b4.blocks, 1)) == size(b4 * b5, 1) @test sum(size.(b5.blocks, 2)) == size(b4 * b5, 2) - b6 = BlockDiagonal([ones(4, 2), 2 * ones(2, 4)]) - @test b4 * b6 isa BlockDiagonal - @test sum(size.(b4.blocks, 1)) == size(b4 * b6, 1) - @test sum(size.(b6.blocks, 2)) == size(b4 * b6, 2) + b6 = BlockDiagonal([ones(4, 1), 2 * ones(2, 2)]) + b46 = b4 * b6 + @test b46 isa BlockDiagonal + @test b46 == [4 * ones(2,1) zeros(2,2); zeros(3,1) 8* ones(3,2)] + @test sum(size.(b4.blocks, 1)) == size(b46, 1) + @test sum(size.(b6.blocks, 2)) == size(b46, 2) + + @test_throws DimensionMismatch b6 * b4 end end # Multiplication end