@@ -100,21 +100,21 @@ end
100
100
Base. getindex (xs:: TrackedArray , i... ; kwargs... ) = track (getindex, xs, i... ; kwargs... )
101
101
102
102
@grad function getindex (xs:: AbstractArray , i... ; kwargs... )
103
- data (xs)[ i... ; kwargs... ] , function (Δ)
103
+ getindex ( data (xs), i... ; kwargs... ) , function (Δ)
104
104
Δ′ = zero (xs)
105
- Δ′[ i... ] = data (Δ )
106
- (nobacksies (:getindex , Δ′), map (_-> nothing , i)... )
107
- end
105
+ setindex! (Δ′, data (Δ), i... ; kwargs ... )
106
+ (nobacksies (:getindex , Δ′), map (_-> nothing , i)... ) # TODO : put kwargs here
107
+ end
108
108
end
109
109
110
110
Base. view (x:: TrackedArray , inds... ; kwargs... ) = track (Base. view, x, inds... ; kwargs... )
111
111
112
112
@grad function view (x:: AbstractArray , inds... ; kwargs... )
113
113
view (data (x), inds... ; kwargs... ), function (Δ)
114
114
grad_output = zero (x)
115
- subgrad = view (grad_output, inds... )
115
+ subgrad = view (grad_output, inds... ; kwargs ... )
116
116
subgrad[:] = data (Δ)
117
- (nobacksies (:view , grad_output), map (_-> nothing , inds)... )
117
+ (nobacksies (:view , grad_output), map (_-> nothing , inds)... ) # TODO : put kwargs here
118
118
end
119
119
end
120
120
0 commit comments