Skip to content

Commit 673505d

Browse files
fix up some of the VectorOfArray dispatches
1 parent 0289ef0 commit 673505d

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.18.0"
4+
version = "2.19.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/zygote.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,22 +54,48 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr
5454
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
5555
# definition first, and finds its own before finding those.
5656

57-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}})
57+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
5858
function AbstractVectorOfArray_getindex_adjoint(Δ)
5959
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
6060
(VectorOfArray(Δ′),nothing)
6161
end
6262
VA[i],AbstractVectorOfArray_getindex_adjoint
6363
end
6464

65-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
65+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{BitArray,AbstractArray{Bool}})
66+
function AbstractVectorOfArray_getindex_adjoint(Δ)
67+
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
68+
(VectorOfArray(Δ′),nothing)
69+
end
70+
VA[i],AbstractVectorOfArray_getindex_adjoint
71+
end
72+
73+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
74+
function AbstractVectorOfArray_getindex_adjoint(Δ)
75+
iter = 0
76+
Δ′ = [(j i ? Δ[iter+=1] : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
77+
(VectorOfArray(Δ′),nothing)
78+
end
79+
VA[i],AbstractVectorOfArray_getindex_adjoint
80+
end
81+
82+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int}})
83+
function AbstractVectorOfArray_getindex_adjoint(Δ)
84+
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
85+
(VectorOfArray(Δ′),nothing)
86+
end
87+
VA[i],AbstractVectorOfArray_getindex_adjoint
88+
end
89+
90+
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...)
6691
function AbstractVectorOfArray_getindex_adjoint(Δ)
6792
Δ′ = [(i == j ? zero(x) : Fill(zero(eltype(x)),size(x))) for (x,j) in zip(VA.u, 1:length(VA))]
6893
Δ′[i][j...] = Δ
6994
(VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
7095
end
7196
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
7297
end
98+
7399
ZygoteRules.@adjoint function ArrayPartition(x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}
74100
function ArrayPartition_adjoint(_y)
75101
y = Array(_y)

0 commit comments

Comments
 (0)