@@ -54,22 +54,48 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr
5454# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
5555# definition first, and finds its own before finding those.
5656
57- ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{ Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} )
57+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int )
5858 function AbstractVectorOfArray_getindex_adjoint (Δ)
5959 Δ′ = [(i == j ? Δ : Fill (zero (eltype (x)),size (x))) for (x,j) in zip (VA. u, 1 : length (VA))]
6060 (VectorOfArray (Δ′),nothing )
6161 end
6262 VA[i],AbstractVectorOfArray_getindex_adjoint
6363end
6464
65- ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} , j:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} ...)
65+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{BitArray,AbstractArray{Bool}} )
66+ function AbstractVectorOfArray_getindex_adjoint (Δ)
67+ Δ′ = [(i[j] ? Δ[j] : Fill (zero (eltype (x)),size (x))) for (x,j) in zip (VA. u, 1 : length (VA))]
68+ (VectorOfArray (Δ′),nothing )
69+ end
70+ VA[i],AbstractVectorOfArray_getindex_adjoint
71+ end
72+
73+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: AbstractArray{Int} )
74+ function AbstractVectorOfArray_getindex_adjoint (Δ)
75+ iter = 0
76+ Δ′ = [(j ∈ i ? Δ[iter+= 1 ] : Fill (zero (eltype (x)),size (x))) for (x,j) in zip (VA. u, 1 : length (VA))]
77+ (VectorOfArray (Δ′),nothing )
78+ end
79+ VA[i],AbstractVectorOfArray_getindex_adjoint
80+ end
81+
82+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Union{Int,AbstractArray{Int}} )
83+ function AbstractVectorOfArray_getindex_adjoint (Δ)
84+ Δ′ = [(i[j] ? Δ[j] : Fill (zero (eltype (x)),size (x))) for (x,j) in zip (VA. u, 1 : length (VA))]
85+ (VectorOfArray (Δ′),nothing )
86+ end
87+ VA[i],AbstractVectorOfArray_getindex_adjoint
88+ end
89+
90+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int , j:: Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}} ...)
6691 function AbstractVectorOfArray_getindex_adjoint (Δ)
6792 Δ′ = [(i == j ? zero (x) : Fill (zero (eltype (x)),size (x))) for (x,j) in zip (VA. u, 1 : length (VA))]
6893 Δ′[i][j... ] = Δ
6994 (VectorOfArray (Δ′), nothing , map (_ -> nothing , j)... )
7095 end
7196 VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
7297end
98+
7399ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
74100 function ArrayPartition_adjoint (_y)
75101 y = Array (_y)
0 commit comments