@@ -18,23 +18,14 @@ function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.Abstra
1818 T (xs), ȳ -> (ChainRulesCore. NoTangent (), ȳ)
1919end
2020
21- @adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int )
22- function AbstractVectorOfArray_getindex_adjoint (Δ)
23- Δ′ = [(i == j ? Δ : FillArrays. Fill (zero (eltype (x)), size (x)))
24- for (x, j) in zip (VA. u, 1 : length (VA))]
25- (VectorOfArray (Δ′), nothing )
26- end
27- VA[i], AbstractVectorOfArray_getindex_adjoint
28- end
29-
3021@adjoint function getindex (VA:: AbstractVectorOfArray ,
3122 i:: Union{BitArray, AbstractArray{Bool}} )
3223 function AbstractVectorOfArray_getindex_adjoint (Δ)
3324 Δ′ = [(i[j] ? Δ[j] : FillArrays. Fill (zero (eltype (x)), size (x)))
3425 for (x, j) in zip (VA. u, 1 : length (VA))]
3526 (VectorOfArray (Δ′), nothing )
3627 end
37- VA[i], AbstractVectorOfArray_getindex_adjoint
28+ VA[:, i], AbstractVectorOfArray_getindex_adjoint
3829end
3930
4031@adjoint function getindex (VA:: AbstractVectorOfArray , i:: AbstractArray{Int} )
4435 for (x, j) in zip (VA. u, 1 : length (VA))]
4536 (VectorOfArray (Δ′), nothing )
4637 end
47- VA[i], AbstractVectorOfArray_getindex_adjoint
48- end
49-
50- @adjoint function getindex (VA:: AbstractVectorOfArray ,
51- i:: Union{Int, AbstractArray{Int}} )
52- function AbstractVectorOfArray_getindex_adjoint (Δ)
53- Δ′ = [(i[j] ? Δ[j] : FillArrays. Fill (zero (eltype (x)), size (x)))
54- for (x, j) in zip (VA. u, 1 : length (VA))]
55- (VectorOfArray (Δ′), nothing )
56- end
57- VA[i], AbstractVectorOfArray_getindex_adjoint
38+ VA[:, i], AbstractVectorOfArray_getindex_adjoint
5839end
5940
6041@adjoint function getindex (VA:: AbstractVectorOfArray , i:: Colon )
6142 function AbstractVectorOfArray_getindex_adjoint (Δ)
6243 (VectorOfArray (Δ), nothing )
6344 end
64- VA[i], AbstractVectorOfArray_getindex_adjoint
45+ VA. u [i], AbstractVectorOfArray_getindex_adjoint
6546end
6647
6748@adjoint function getindex (VA:: AbstractVectorOfArray , i:: Int ,
0 commit comments