Skip to content

Commit 27d82ef

Browse files
authored
Merge pull request #9 from mcabbott/rev
`slicemap` with `rev=true`
2 parents 3774efc + 150b947 commit 27d82ef

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/SliceMap.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,25 @@ The gradient is for Zygote only.
9595
9696
Parameters within the function `f` (if there are any) should be correctly tracked,
9797
which is not the case for `mapcols()`.
98+
Keyword `rev=true` will apply `f` to `Iterators.reverse(Slices(A,...))`, thus iterating
99+
in the opposite order.
98100
"""
99-
function slicemap(f, A::AbstractArray{T,N}, args...; dims) where {T,N}
101+
function slicemap(f, A::AbstractArray{T,N}, args...; dims, rev::Bool=false) where {T,N}
100102
code = ntuple(d -> d in dims ? True() : False(), N)
101103
B = JuliennedArrays.Slices(A, code...)
102-
C = [ f(slice, args...) for slice in B ]
104+
C = if rev==false
105+
[ f(slice, args...) for slice in B ]
106+
else
107+
R = [ f(slice, args...) for slice in iter_reverse(B) ]
108+
iter_reverse(R)
109+
end
103110
JuliennedArrays.Align(C, code...)
104111
end
105112

113+
iter_reverse(x) = collect(Iterators.reverse(x))
114+
115+
@adjoint iter_reverse(x) = iter_reverse(x), dy -> (iter_reverse(dy),)
116+
106117
#========== Forward, Static ==========#
107118

108119
using StaticArrays, ForwardDiff

test/runtests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,28 @@ end
167167
# https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652
168168

169169
end
170+
@testset "slicemap rev=true" begin
171+
172+
rec(store) = x -> (push!(store, first(x)); x)
173+
A = [1,1,1] .* (1:5)'
174+
175+
store = []
176+
slicemap(rec(store), A; dims=1)
177+
@test store == 1:5
178+
179+
store = []
180+
slicemap(rec(store), A; dims=1, rev=true)
181+
@test store == 5:-1:1
182+
183+
# gradient check as above
184+
ten = randn(3,4,5,2)
185+
fun(x::AbstractVector) = sqrt(3) .+ x.^3 ./ (sum(x)^2)
186+
res = mapslices(fun, ten, dims=3)
187+
@test res slicemap(fun, ten; dims=3, rev=true)
188+
@test res slicemap(fun, ten; dims=3, rev=false)
189+
190+
grad = ForwardDiff.gradient(x -> sum(sin, mapslices(fun, x, dims=3)), ten)
191+
@test grad Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3, rev=false)), ten)[1]
192+
@test grad Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3, rev=true)), ten)[1]
193+
194+
end

0 commit comments

Comments
 (0)