Skip to content

Commit 859c152

Browse files
author
Michael Abbott
committed
fix tests for PR, add comments
1 parent 836b1ba commit 859c152

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/SliceMap.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ It provides a gradient for Tracker and Zygote, saving the backward function for
2525
Any arguments after the matrix are passed to `f` as scalars, i.e.
2626
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eeachcol(m))`.
2727
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
28+
29+
Note that if `f` itself contains parameters, their gradients are also not tracked.
2830
"""
2931
mapcols(f, M, args...) = _mapcols(map, f, M, args...)
3032
tmapcols(f, M, args...) = _mapcols(threadmap, f, M, args...)
@@ -90,6 +92,9 @@ Like `mapcols()`, but for any slice. The function `f` must preserve shape,
9092
e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
9193
9294
The gradient is for Zygote only.
95+
96+
Parameters within the function `f` (if there are any) should be correctly tracked,
97+
which is not the case for `mapcols()`.
9398
"""
9499
function slicemap(f, A::AbstractArray{T,N}, args...; dims) where {T,N}
95100
code = ntuple(d -> d in dims ? True() : False(), N)

test/runtests.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,24 @@ end
146146
xp = Tracker.param(x)
147147
Tracker.back!(sum(mapcols(F(wp), xp)))
148148
@test Tracker.grad(xp) gradx
149-
@test_broken Tracker.grad(wp) gradw # zero
149+
@test Tracker.grad(wp) == 0 .* gradw # bug or a feature?
150150

151-
grad_mapcols = Zygote.gradient(() -> sum(mapcols(F(w), x)), Zygote.Params([w,x]))
151+
# fp = F(wp)
152+
# wp.grad .= 0; xp.grad .= 0;
153+
# Tracker.back!(sum(mapcols(fp, xp)))
154+
# @test Tracker.grad(xp) ≈ gradx
155+
# @test_broken Tracker.grad(wp) ≈ gradw # zero
156+
157+
f = F(w)
158+
grad_mapcols = Zygote.gradient(() -> sum(mapcols(f, x)), Zygote.Params([w,x]))
152159
@test grad_mapcols[x] gradx
153-
@test_broken grad_mapcols[w] gradw # grad_mapcols[w] === nothing
160+
@test grad_mapcols[w] == nothing # bug or a feature?
154161

155-
grad_slicemap = Zygote.gradient(() -> sum(slicemap(F(w), x, dims=1)), Zygote.Params([w,x]))
162+
grad_slicemap = Zygote.gradient(() -> sum(slicemap(f, x, dims=1)), Zygote.Params([w,x]))
156163
@test grad_slicemap[x] gradx
157-
@test_broken grad_slicemap[w] gradw # wrong numbers
164+
@test grad_slicemap[w] gradw
158165
@test gradw Zygote.gradient(w -> sum(slicemap(F(w), x, dims=1)), w)[1]
166+
# Using F(w) with Params() gives wrong answers:
167+
# https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652
159168

160169
end

0 commit comments

Comments
 (0)