Skip to content

Commit 00fab43

Browse files
authored
Correct +(::BlockDiagonal, ::StridedMatrix) for non-square blocks (#125)
* Add erroring test * Fix `+(::BlockDiagonal,::StridedMatrix)` for non-square blocks * Bump patch version * Add addition test for dense `Matrix` with non-blocksquare `BlockDiagonal`
1 parent 3ac9c89 commit 00fab43

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.40"
4+
version = "0.1.41"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/base_maths.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@ function Base.:+(B::BlockDiagonal, M::StridedMatrix)
3333
end
3434
A = copy(M)
3535
row = 1
36-
for (j, block) in enumerate(blocks(B))
37-
nrows = size(block, 1)
38-
rows = row:(row + nrows-1)
39-
A[rows, rows] .+= block
36+
col = 1
37+
for block in blocks(B)
38+
nrows, ncols = size(block)
39+
rows = row:(row + nrows - 1)
40+
cols = col:(col + ncols - 1)
41+
A[rows, cols] .+= block
4042
row += nrows
43+
col += ncols
4144
end
4245
return A
4346
end

test/base_maths.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Test
2222
bi = BlockDiagonal([zeros(Int, N1, N1), zeros(Int, N2, N2), zeros(Int, N3, N3)])
2323

2424
bns = BlockDiagonal([rand(rng, N1, N2), rand(rng, N2, N3), rand(rng, N3, N1)])
25+
C = rand(rng, N, N)
2526

2627
@testset "Addition" begin
2728
@testset "BlockDiagonal + BlockDiagonal" begin
@@ -35,6 +36,11 @@ using Test
3536
@test b1 + Matrix(b1) == b1 + b1
3637
@test_throws DimensionMismatch b1 + Matrix(b3)
3738

39+
# Test on non-square blocks
40+
# https://github.com/invenia/BlockDiagonals.jl/issues/124
41+
@test bns + Matrix(bns) == Matrix(bns) + Matrix(bns)
42+
@test bns + C == Matrix(bns) + C
43+
3844
# Matrix + BlockDiagonal
3945
@test Matrix(b1) + b1 isa Matrix
4046
@test Matrix(b1) + b1 == b1 + b1

0 commit comments

Comments
 (0)