@@ -62,7 +62,7 @@ Base.copy(x::TrackedArray) = x
62
62
63
63
collect (xs:: TrackedArray ) = xs
64
64
65
- Base. setindex! (xs:: TrackedArray , v, i... ) =
65
+ Base. setindex! (xs:: TrackedArray , v, i... ; kwargs ... ) =
66
66
error (" Can't differentiate `setindex!`" )
67
67
68
68
back! (:: TrackedArray ) = error (" Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`" )
97
97
98
98
# Array Stdlib
99
99
100
- Base. getindex (xs:: TrackedArray , i... ) = track (getindex, xs, i... )
100
+ Base. getindex (xs:: TrackedArray , i... ; kwargs ... ) = track (getindex, xs, i... ; kwargs ... )
101
101
102
- @grad function getindex (xs:: AbstractArray , i... )
103
- data (xs)[ i... ] , function (Δ)
104
- Δ′ = zero (xs)
105
- Δ′[ i... ] = data (Δ )
106
- (nobacksies (:getindex , Δ′), map (_-> nothing , i)... )
107
- end
102
+ @grad function getindex (xs:: AbstractArray , i... ; kwargs ... )
103
+ getindex ( data (xs), i... ; kwargs ... ) , function (Δ)
104
+ Δ′ = zero (xs)
105
+ setindex! (Δ′, data (Δ), i... ; kwargs ... )
106
+ (nobacksies (:getindex , Δ′), map (_-> nothing , i)... )
107
+ end
108
108
end
109
109
110
- Base. view (x:: TrackedArray , inds... ) = track (Base. view, x, inds... )
110
+ Base. view (x:: TrackedArray , inds... ; kwargs ... ) = track (Base. view, x, inds... ; kwargs ... )
111
111
112
- @grad function view (x:: AbstractArray , inds... )
113
- view (data (x), inds... ), function (Δ)
112
+ @grad function view (x:: AbstractArray , inds... ; kwargs ... )
113
+ view (data (x), inds... ; kwargs ... ), function (Δ)
114
114
grad_output = zero (x)
115
- subgrad = view (grad_output, inds... )
115
+ subgrad = view (grad_output, inds... ; kwargs ... )
116
116
subgrad[:] = data (Δ)
117
117
(nobacksies (:view , grad_output), map (_-> nothing , inds)... )
118
118
end
0 commit comments