Skip to content

Commit e40d0bd

Browse files
author
Michael Abbott
committed
remove some ::Function, and add tests f(w)(x)
1 parent 897cf2c commit e40d0bd

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

src/SliceMap.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ Any arguments after the matrix are passed to `f` as scalars, i.e.
2626
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eeachcol(m))`.
2727
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
2828
"""
29-
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
30-
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
29+
mapcols(f, M, args...) = _mapcols(map, f, M, args...)
30+
tmapcols(f, M, args...) = _mapcols(threadmap, f, M, args...)
3131

32-
function _mapcols(map::Function, f::Function, M::AbstractMatrix, args...)
32+
function _mapcols(map::Function, f, M::AbstractMatrix, args...)
3333
res = map(col -> _vec(f(col, args...)), eachcol(M))
3434
eltype(res) <: AbstractVector ? reduce(hcat, res) : reshape(res,1,:)
3535
end
3636

3737
_vec(x) = x
3838
_vec(A::AbstractArray) = vec(A) # to allow f vector -> matrix, by reshaping
3939

40-
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
40+
_mapcols(map::Function, f, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
4141

42-
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
42+
@grad _mapcols(map::Function, f, M::AbstractMatrix, args...) =
4343
∇mapcols(map, map(col -> Tracker.forward(x -> _vec(f(x, args...)), col), eachcol(data(M))), args...)
4444

45-
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
45+
@adjoint _mapcols(map::Function, f, M::AbstractMatrix, args...) =
4646
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> _vec(f(x, args...)), col), eachcol(M)), args)
4747

4848
function ∇mapcols(bigmap, forwards, args...)
@@ -91,7 +91,7 @@ e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
9191
9292
The gradient is for Zygote only.
9393
"""
94-
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
94+
function slicemap(f, A::AbstractArray{T,N}, args...; dims) where {T,N}
9595
code = ntuple(d -> d in dims ? True() : False(), N)
9696
B = JuliennedArrays.Slices(A, code...)
9797
C = [ f(slice, args...) for slice in B ]

test/runtests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,27 @@ end
133133
@test grad Zygote.gradient(m -> sum(sin, j3(fun, m)), ten)[1]
134134

135135
end
136+
@testset "gradient of the function" begin
137+
138+
struct F W end
139+
(f::F)(x) = f.W * x # toy version of e.g. Flux.Dense
140+
w = rand(3,2)
141+
x = rand(2,5)
142+
gradx = ForwardDiff.gradient(x -> sum(mapslices(F(w), x, dims=1)), x)
143+
gradw = ForwardDiff.gradient(w -> sum(mapslices(F(w), x, dims=1)), w)
144+
145+
wp = Tracker.param(w)
146+
xp = Tracker.param(x)
147+
Tracker.back!(sum(mapcols(F(wp), xp)))
148+
@test Tracker.grad(xp) gradx
149+
@test_broken Tracker.grad(wp) gradw # zero
150+
151+
grad_mapcols = Zygote.gradient(() -> sum(mapcols(F(w), x)), Zygote.Params([w,x]))
152+
@test grad_mapcols[x] gradx
153+
@test_broken grad_mapcols[w] gradw # grad_mapcols[w] === nothing
154+
155+
grad_slicemap = Zygote.gradient(() -> sum(slicemap(F(w), x, dims=1)), Zygote.Params([w,x]))
156+
@test grad_slicemap[x] gradx
157+
@test_broken grad_slicemap[w] gradw # wrong numbers
158+
159+
end

0 commit comments

Comments
 (0)