Skip to content

Commit 150b947

Browse files
author
Michael Abbott
committed
fix gradient for rev=true
1 parent b823426 commit 150b947

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/SliceMap.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,16 @@ function slicemap(f, A::AbstractArray{T,N}, args...; dims, rev::Bool=false) wher
104104
C = if rev==false
105105
[ f(slice, args...) for slice in B ]
106106
else
107-
R = [ f(slice, args...) for slice in Iterators.reverse(B) ]
108-
collect(Iterators.reverse(R))
107+
R = [ f(slice, args...) for slice in iter_reverse(B) ]
108+
iter_reverse(R)
109109
end
110110
JuliennedArrays.Align(C, code...)
111111
end
112112

113+
iter_reverse(x) = collect(Iterators.reverse(x))
114+
115+
@adjoint iter_reverse(x) = iter_reverse(x), dy -> (iter_reverse(dy),)
116+
113117
#========== Forward, Static ==========#
114118

115119
using StaticArrays, ForwardDiff

0 commit comments

Comments
 (0)