Skip to content

Commit 80e220d

Browse files
fix tests
1 parent 7c97b62 commit 80e220d

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

ext/ArrayInterfaceReverseDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal,N}
1616
if length(x) > 1
1717
reduce(vcat,x)
1818
else
19+
@show "here?"
1920
reduce(vcat,[x[1],x[1]])[1:1]
2021
end
2122
end

test/ad.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
using ArrayInterface, ReverseDiff, Tracker, Test
2-
x = reduce(vcat, ReverseDiff.track([4.0]))
2+
x = ReverseDiff.track([4.0])
33
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
44
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
55
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
6+
x = [ReverseDiff.track([4.0])[1]]
7+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
8+
x = reduce(vcat, ReverseDiff.track([4.0,4.0]))
9+
x = [x[1],x[2]]
10+
@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray
611

7-
x = identity.(Tracker.TrackedArray([4.0]))
12+
x = Tracker.TrackedArray([4.0])
13+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
14+
x = [Tracker.TrackedArray([4.0])[1]]
15+
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
16+
x = Tracker.TrackedArray([4.0,4.0])
817
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray
9-
x = identity.(Tracker.TrackedArray([4.0,4.0]))
18+
x = reduce(vcat, Tracker.TrackedArray([4.0,4.0]))
19+
x = [x[1],x[2]]
1020
@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray

0 commit comments

Comments
 (0)