Skip to content

Commit 70ddc17

Browse files
Update rrules
1 parent 80d1e1a commit 70ddc17

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

src/chainrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function ChainRulesCore.rrule(
2727
::typeof(*),
2828
bm::BlockDiagonal{T, V},
2929
v::StridedVector{T}
30-
) where {T<:Union{Real, Complex}, V<:Matrix{T}}
30+
) where {T<:Union{Real, Complex}, V}
3131

3232
y = bm * v
3333

test/chainrules.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
@testset "chainrules.jl" begin
2-
@testset "BlockDiagonal" begin
3-
x = [randn(1, 2), randn(2, 2)]
4-
test_rrule(BlockDiagonal, x)
5-
end
2+
@testset for V in (Tuple, Vector)
3+
@testset "BlockDiagonal" begin
4+
x = V([randn(1, 2), randn(2, 2)])
5+
test_rrule(BlockDiagonal, x)
6+
end
67

7-
@testset "Matrix" begin
8-
D = BlockDiagonal([randn(1, 2), randn(2, 2)])
9-
test_rrule(Matrix, D)
10-
end
8+
@testset "Matrix" begin
9+
B = BlockDiagonal(V([randn(1, 2), randn(2, 2)]))
10+
test_rrule(Matrix, B)
11+
end
1112

12-
@testset "BlockDiagonal * Vector" begin
13-
D = BlockDiagonal([rand(2, 3), rand(3, 3)])
14-
v = rand(6)
15-
test_rrule(*, D, v)
13+
@testset "BlockDiagonal * Vector" begin
14+
B = BlockDiagonal(V([rand(2, 3), rand(3, 3)]))
15+
v = rand(6)
16+
test_rrule(*, B, v)
17+
end
1618
end
1719
end

0 commit comments

Comments
 (0)