Skip to content

Commit c88e61b

Browse files
test: broadcasting preserves nested types
1 parent 60f684c commit c88e61b

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

test/adjoints.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,7 @@ loss(x)
9292
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
9393
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
9494
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
95+
96+
x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2))
97+
g = Zygote.gradient(norm, x)[1]
98+
@test g isa typeof(x)

test/basic_indexing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,9 @@ x = VectorOfArray(StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1)))
280280
y = 2 * x
281281
@. x = y
282282
@test all(all.(y .== x))
283+
284+
285+
x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2))
286+
@test (x .* 1.2) isa ArrayPartition{<:Any, <:ArrayPartition}
287+
288+
g = Zygote.gradient(norm, x)[1]

0 commit comments

Comments
 (0)