Skip to content

Commit b3b0106

Browse files
authored
Merge pull request #2 from FluxML/wct/fix-getindex
Fix indexing
2 parents fcdaebd + eb9a462 commit b3b0106

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/lib/array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ Base.getindex(xs::TrackedArray, i...; kwargs...) = track(getindex, xs, i...; kwa
107107
end
108108
end
109109

110+
@grad function getindex(xs::AbstractArray, i::Array...)
111+
data(xs)[i...], function (Δ)
112+
Δ′ = zero(xs)
113+
@views Δ′[i...] .+= data(Δ)
114+
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
115+
end
116+
end
117+
110118
Base.view(x::TrackedArray, inds...; kwargs...) = track(Base.view, x, inds...; kwargs...)
111119

112120
@grad function view(x::AbstractArray, inds...; kwargs...)

test/tracker.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ end
122122

123123
end
124124

125+
@testset "getindex (Nabla.jl - #139)" begin
126+
z = [2, 3, 3]
127+
@test gradtest(x->x[z], randn(MersenneTwister(123456), 3))
128+
end
129+
125130
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
126131
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
127132

0 commit comments

Comments
 (0)