|
1 | | -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i) |
| 1 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int) |
2 | 2 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
3 | 3 | Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] |
4 | 4 | (NoTangent(),Δ′,NoTangent()) |
5 | 5 | end |
6 | 6 | VA[i],AbstractVectorOfArray_getindex_adjoint |
7 | 7 | end |
8 | 8 |
|
9 | | -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...) |
| 9 | +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N} |
10 | 10 | function AbstractVectorOfArray_getindex_adjoint(Δ) |
11 | 11 | Δ′ = zero(VA) |
12 | | - Δ′[i,j...] = Δ |
13 | | - (NoTangent(), Δ′, i,map(_ -> NoTangent(), j)...) |
| 12 | + Δ′[indices...] = Δ |
| 13 | + (NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...) |
14 | 14 | end |
15 | | - VA[i,j...],AbstractVectorOfArray_getindex_adjoint |
| 15 | + VA[indices...],AbstractVectorOfArray_getindex_adjoint |
16 | 16 | end |
17 | 17 |
|
18 | 18 | function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x} |
|
0 commit comments