Skip to content

Commit deeb0f6

Browse files
Fix some more VoA adjoints
1 parent ecb177e commit deeb0f6

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.19.0"
4+
version = "2.19.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/zygote.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,18 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
8787
VA[i],AbstractVectorOfArray_getindex_adjoint
8888
end
8989

90+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
91+
function AbstractVectorOfArray_getindex_adjoint(Δ)
92+
(VectorOfArray(Δ),nothing)
93+
end
94+
VA[i],AbstractVectorOfArray_getindex_adjoint
95+
end
96+
9097
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
9198
function AbstractVectorOfArray_getindex_adjoint(Δ)
92-
Δ′ = [(i == j ? zero(x) : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
93-
Δ′[i][j...] = Δ
94-
(VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
99+
Δ′ = VectorOfArray([zero(x) for (x,j) in zip(VA.u, 1:length(VA))])
100+
Δ′[i,j...] = Δ
101+
(Δ′, nothing, map(_ -> nothing, j)...)
95102
end
96103
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
97104
end

0 commit comments

Comments
 (0)