Skip to content

Commit b823426

Browse files
author
Michael Abbott
committed
add rev=true to slicemap
1 parent 5f71649 commit b823426

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/SliceMap.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,18 @@ The gradient is for Zygote only.
9595
9696
Parameters within the function `f` (if there are any) should be correctly tracked,
9797
which is not the case for `mapcols()`.
98+
Keyword `rev=true` will apply `f` to `Iterators.reverse(Slices(A,...))`, thus iterating
99+
in the opposite order.
98100
"""
99-
function slicemap(f, A::AbstractArray{T,N}, args...; dims) where {T,N}
101+
function slicemap(f, A::AbstractArray{T,N}, args...; dims, rev::Bool=false) where {T,N}
100102
code = ntuple(d -> d in dims ? True() : False(), N)
101103
B = JuliennedArrays.Slices(A, code...)
102-
C = [ f(slice, args...) for slice in B ]
104+
C = if rev==false
105+
[ f(slice, args...) for slice in B ]
106+
else
107+
R = [ f(slice, args...) for slice in Iterators.reverse(B) ]
108+
collect(Iterators.reverse(R))
109+
end
103110
JuliennedArrays.Align(C, code...)
104111
end
105112

test/runtests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,28 @@ end
167167
# https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652
168168

169169
end
170+
@testset "slicemap rev=true" begin
171+
172+
rec(store) = x -> (push!(store, first(x)); x)
173+
A = [1,1,1] .* (1:5)'
174+
175+
store = []
176+
slicemap(rec(store), A; dims=1)
177+
@test store == 1:5
178+
179+
store = []
180+
slicemap(rec(store), A; dims=1, rev=true)
181+
@test store == 5:-1:1
182+
183+
# gradient check as above
184+
ten = randn(3,4,5,2)
185+
fun(x::AbstractVector) = sqrt(3) .+ x.^3 ./ (sum(x)^2)
186+
res = mapslices(fun, ten, dims=3)
187+
@test res slicemap(fun, ten; dims=3, rev=true)
188+
@test res slicemap(fun, ten; dims=3, rev=false)
189+
190+
grad = ForwardDiff.gradient(x -> sum(sin, mapslices(fun, x, dims=3)), ten)
191+
@test grad Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3, rev=false)), ten)[1]
192+
@test grad Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3, rev=true)), ten)[1]
193+
194+
end

0 commit comments

Comments
 (0)