Skip to content

Commit fda7259

Browse files
authored
Merge pull request #58 from invenia/mz/mul
rrule for BlockDiagonal * Vector multiplication
2 parents 57d8490 + 7ce3448 commit fda7259

File tree

6 files changed

+50
-8
lines changed

6 files changed

+50
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/Manifest.toml
22
docs/build/
3+
dev/*

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.13"
4+
version = "0.1.14"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,7 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010

1111
[compat]
1212
ChainRulesCore = "0.9"
13-
ChainRulesTestUtils = "0.6"
13+
ChainRulesTestUtils = "0.6.3"
1414
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10"
1515
julia = "1"
1616

src/chainrules.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,38 @@ function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagon
2222
return Matrix(B), Matrix_pullback
2323
end
2424

25+
# multiplication
26+
function ChainRulesCore.rrule(
27+
::typeof(*),
28+
bm::BlockDiagonal{T, V},
29+
v::StridedVector{T}
30+
) where {T<:Union{Real, Complex}, V<:Matrix{T}}
31+
32+
y = bm * v
33+
34+
# needed for computing Δ * v' blockwise
35+
nrows = size.(bm.blocks, 1)
36+
ncols = size.(bm.blocks, 2)
37+
row_idxs = cumsum(nrows) .- nrows .+ 1
38+
col_idxs = cumsum(ncols) .- ncols .+ 1
39+
40+
function bm_vector_mul_pullback(Δ)
41+
Δblocks = map(eachindex(nrows)) do i
42+
block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1)
43+
block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1)
44+
return InplaceableThunk(
45+
@thunk(Δ[block_rows] * v[block_cols]'),
46+
-> mul!(X̄, Δ[block_rows], v[block_cols]', true, true)
47+
)
48+
end
49+
return (
50+
NO_FIELDS,
51+
Composite{BlockDiagonal{T, V}}(;blocks=Δblocks),
52+
InplaceableThunk(
53+
@thunk(bm' * Δ),
54+
-> mul!(X̄, bm', Δ, true, true)
55+
),
56+
)
57+
end
58+
return y, bm_vector_mul_pullback
59+
end

test/blockdiagonal.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,6 @@ using BlockDiagonals: isequal_blocksizes
33
using Random
44
using Test
55

6-
function FiniteDifferences.to_vec(X::BlockDiagonal)
7-
x, blocks_from_vec = to_vec(X.blocks)
8-
BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec))
9-
return x, BlockDiagonal_from_vec
10-
end
11-
126
@testset "blockdiagonal.jl" begin
137
rng = MersenneTwister(123456)
148
N1, N2, N3 = 3, 4, 5

test/chainrules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,10 @@
88
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
99
test_rrule(Matrix, D)
1010
end
11+
12+
@testset "BlockDiagonal * Vector" begin
13+
D = BlockDiagonal([rand(2, 3), rand(3, 3)])
14+
v = rand(6)
15+
test_rrule(*, D, v)
16+
end
1117
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ using FiniteDifferences # For overloading to_vec
66
using Test
77
using LinearAlgebra
88

9+
function FiniteDifferences.to_vec(X::BlockDiagonal)
10+
x, blocks_from_vec = to_vec(X.blocks)
11+
BlockDiagonal_from_vec(x_vec) = BlockDiagonal(blocks_from_vec(x_vec))
12+
return x, BlockDiagonal_from_vec
13+
end
14+
915
@testset "BlockDiagonals" begin
1016
# The doctests fail on x86, so only run them on 64-bit hardware
1117
Sys.WORD_SIZE == 64 && doctest(BlockDiagonals)

0 commit comments

Comments
 (0)