Skip to content

Commit 5f71649

Browse files
authored
Merge pull request #7 from mcabbott/impure
Allow f::Any, possibly containing parameters
2 parents 897cf2c + 7365add commit 5f71649

File tree

4 files changed

+48
-11
lines changed

4 files changed

+48
-11
lines changed

.travis.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
language: julia
22
os:
33
- linux
4-
- osx
5-
- windows
64
julia:
75
- 1.1
8-
- 1.2
6+
- 1.4
97
- nightly
108

119
matrix:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SliceMap"
22
uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8"
33
authors = ["Michael Abbott"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

src/SliceMap.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,26 @@ It provides a gradient for Tracker and Zygote, saving the backward function for
2525
Any arguments after the matrix are passed to `f` as scalars, i.e.
2626
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eachcol(m))`.
2727
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
28+
29+
Note that if `f` itself contains parameters, their gradients are also not tracked.
2830
"""
29-
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
30-
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
31+
mapcols(f, M, args...) = _mapcols(map, f, M, args...)
32+
tmapcols(f, M, args...) = _mapcols(threadmap, f, M, args...)
3133

32-
function _mapcols(map::Function, f::Function, M::AbstractMatrix, args...)
34+
function _mapcols(map::Function, f, M::AbstractMatrix, args...)
3335
res = map(col -> _vec(f(col, args...)), eachcol(M))
3436
eltype(res) <: AbstractVector ? reduce(hcat, res) : reshape(res,1,:)
3537
end
3638

3739
_vec(x) = x
3840
_vec(A::AbstractArray) = vec(A) # to allow f vector -> matrix, by reshaping
3941

40-
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
42+
_mapcols(map::Function, f, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
4143

42-
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
44+
@grad _mapcols(map::Function, f, M::AbstractMatrix, args...) =
4345
∇mapcols(map, map(col -> Tracker.forward(x -> _vec(f(x, args...)), col), eachcol(data(M))), args...)
4446

45-
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
47+
@adjoint _mapcols(map::Function, f, M::AbstractMatrix, args...) =
4648
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> _vec(f(x, args...)), col), eachcol(M)), args)
4749

4850
function ∇mapcols(bigmap, forwards, args...)
@@ -90,8 +92,11 @@ Like `mapcols()`, but for any slice. The function `f` must preserve shape,
9092
e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
9193
9294
The gradient is for Zygote only.
95+
96+
Parameters within the function `f` (if there are any) should be correctly tracked,
97+
which is not the case for `mapcols()`.
9398
"""
94-
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
99+
function slicemap(f, A::AbstractArray{T,N}, args...; dims) where {T,N}
95100
code = ntuple(d -> d in dims ? True() : False(), N)
96101
B = JuliennedArrays.Slices(A, code...)
97102
C = [ f(slice, args...) for slice in B ]

test/runtests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,37 @@ end
133133
@test grad ≈ Zygote.gradient(m -> sum(sin, j3(fun, m)), ten)[1]
134134

135135
end
136+
@testset "gradient of the function" begin
137+
138+
struct F W end
139+
(f::F)(x) = f.W * x # toy version of e.g. Flux.Dense
140+
w = rand(3,2)
141+
x = rand(2,5)
142+
gradx = ForwardDiff.gradient(x -> sum(mapslices(F(w), x, dims=1)), x)
143+
gradw = ForwardDiff.gradient(w -> sum(mapslices(F(w), x, dims=1)), w)
144+
145+
wp = Tracker.param(w)
146+
xp = Tracker.param(x)
147+
Tracker.back!(sum(mapcols(F(wp), xp)))
148+
@test Tracker.grad(xp) ≈ gradx
149+
@test Tracker.grad(wp) == 0 .* gradw # bug or a feature?
150+
151+
# fp = F(wp)
152+
# wp.grad .= 0; xp.grad .= 0;
153+
# Tracker.back!(sum(mapcols(fp, xp)))
154+
# @test Tracker.grad(xp) ≈ gradx
155+
# @test_broken Tracker.grad(wp) ≈ gradw # zero
156+
157+
f = F(w)
158+
grad_mapcols = Zygote.gradient(() -> sum(mapcols(f, x)), Zygote.Params([w,x]))
159+
@test grad_mapcols[x] ≈ gradx
160+
@test grad_mapcols[w] == nothing # bug or a feature?
161+
162+
grad_slicemap = Zygote.gradient(() -> sum(slicemap(f, x, dims=1)), Zygote.Params([w,x]))
163+
@test grad_slicemap[x] ≈ gradx
164+
@test grad_slicemap[w] ≈ gradw
165+
@test gradw ≈ Zygote.gradient(w -> sum(slicemap(F(w), x, dims=1)), w)[1]
166+
# Using F(w) with Params() gives wrong answers:
167+
# https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652
168+
169+
end

0 commit comments

Comments
 (0)