Skip to content

Commit dd3b0a8

Browse files
author
Michael Abbott
committed
day two
1 parent 47e3b52 commit dd3b0a8

File tree

3 files changed

+162
-8
lines changed

3 files changed

+162
-8
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ version = "0.1.0"
55

66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
89
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
910
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
1011
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1112
WeightedArrays = "379a43df-f81c-573e-83a6-069eb6c11a71"
1213
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1314

15+
[compat]
16+
julia = "1"
17+
1418
[extras]
1519
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1620

README.md

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
1919
Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Tracker.forward per slice
2020
Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1] # ForwardDiff on slices
2121

22-
# Zygote.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
22+
# Zygote.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat) # errors
2323
Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Zygote.forward
2424
Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
2525
```
@@ -36,15 +36,28 @@ mat1k = rand(3,1000);
3636

3737
@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms
3838
@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms
39-
@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 255.032 μs
40-
@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms
39+
@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 255.032 μs, 690.09 KiB
40+
@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms, 3.82 MiB
4141
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 354.112 μs
4242
```
4343

4444
Of course `mapslices()` does things other than columns of matrices.
4545
Most of these can be done better with `eachslice()` and `reduce(hcat,...)`,
4646
maybe with some thought one could just write gradients for those.
4747

48+
Perhaps done. The views of `eachcol()` have quite inefficient gradients,
49+
but `collecteachcol()` is efficient:
50+
51+
```julia
52+
@btime Zygote.gradient(m -> sum(sin, mapcols4(fun, m)), $mat1k); # 45.616 ms, 49.49 MiB
53+
@btime Zygote.gradient(m -> sum(sin, mapcols6(fun, m)), $mat1k); # 18.655 ms, 3.37 MiB
54+
```
55+
56+
<!--
4857
Or for the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
4958
which now does some mapslices things (and will soon do many more) by chaining such functions.
59+
-->
5060

61+
Issues about mapslices:
62+
* https://github.com/FluxML/Zygote.jl/issues/92
63+
* https://github.com/FluxML/Flux.jl/issues/741

src/SliceMap.jl

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,49 @@ module SliceMap
33

44
export MapCols, mapcols
55

6+
#========== Gradient Macro ==========#
7+
8+
using MacroTools, Tracker, Zygote
9+
using Tracker: TrackedMatrix, track, @grad, data
10+
using Zygote: @adjoint, _zero
11+
12+
# `@gradadjoint function name(args...) ... end`
# Takes a gradient definition written in Tracker's `@grad` form and emits the
# corresponding `Tracker._forward` method (built by `trackergrad`).
# The intent is to also emit the Zygote adjoint from the same definition,
# but that part is disabled because it does not work yet.
macro gradadjoint(ex)
    tracker_def = trackergrad(ex)
    quote
        # $(Zygote.gradm(ex)) # this doesn't work
        $tracker_def
    end
end
18+
19+
# Copied from https://github.com/FluxML/Tracker.jl/blob/master/src/Tracker.jl#L55
20+
# Turn a function definition expression `ex` (long or short form, optionally with
# a `where {T...}` clause) into an escaped definition of
# `Tracker._forward(::typeof(name), args...) where {T...} = body`.
# Throws an error if `ex` is not a function definition.
function trackergrad(ex)
    @capture(shortdef(ex), (name_(args__) = body_) |
                           (name_(args__) where {T__} = body_)) || error("Need a function definition")
    # `T` is left as `nothing` when the pattern without `where` matched;
    # compare by identity rather than with `==`.
    T === nothing && (T = [])
    # A bare function name becomes an anonymous `::typeof(name)` argument.
    isexpr(name, :(::)) || (name = :(::typeof($name)))
    # The function argument goes first, but after a keyword `parameters` block if
    # there is one. Guard against an empty argument list before looking at `args[1]`.
    pos = (!isempty(args) && isexpr(args[1], :parameters)) ? 2 : 1
    insert!(args, pos, name)
    MacroTools.@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
28+
629
#========== Reverse, Eachslice ==========#
730

31+
using WeightedArrays
32+
833
"""
934
mapcols(f, m::Matrix, args...) = reduce(hcat, f(c, args...) for c in eachcol(m))
1035
1136
When `m::TrackedMatrix`, it saves the backward function for each slice.
37+
All further arguments are scalar constants, i.e. they do not get sliced/iterated (unlike `map`)
38+
nor are their gradients tracked.
1239
"""
13-
mapcols(f::Function, M::Matrix, args...) =
40+
function mapcols(f::Function, M::AbstractMatrix, args...)
    # Apply `f` to each column (extra `args` are passed through unchanged),
    # flatten every result with `rvec`, then glue the pieces side by side.
    pieces = [ rvec(f(col, args...)) for col in eachcol(M) ]
    return reduce(hcat, pieces)
end
1542

16-
using Tracker
17-
using Tracker: TrackedMatrix, track, @grad, data
43+
function mapcols(f::Function, M::WeightedMatrix, args...)
    # Map over the wrapped array only; `weights` and `opt` are carried over as they are.
    mapped = mapcols(f, M.array, args...)
    return Weighted(mapped, M.weights, M.opt)
end
1845

1946
function mapcols(f::Function, M::TrackedMatrix, args...)
    # Route tracked input through Tracker's `track`, which records the call
    # together with `f`, `M` and the extra arguments.
    return track(mapcols, f, M, args...)
end
2047

21-
@grad function mapcols(f::Function, M::TrackedMatrix, args...)
48+
@gradadjoint function mapcols(f::Function, M::AbstractMatrix, args...)
2249
res = [ Tracker.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
2350
fwd = reduce(hcat, data.(first.(res)))
2451
function back(Δ)
@@ -29,7 +56,7 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
2956
fwd, back
3057
end
3158

32-
using Zygote
59+
# @gradadjoint not yet working
3360
Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
3461
res = [ Zygote.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
3562
fwd = reduce(hcat, data.(first.(res)))
@@ -41,6 +68,23 @@ Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
4168
fwd, back
4269
end
4370

71+
function maprows(f::Function, M::AbstractMatrix, args...)
    # Row-wise counterpart of `mapcols`: apply `f` to each row (extra `args` are
    # passed through unchanged), turn each result into a row with `tvec`,
    # and stack the rows vertically.
    strips = [ tvec(f(row, args...)) for row in eachrow(M) ]
    return reduce(vcat, strips)
end
73+
74+
function maprows(f::Function, M::TrackedMatrix, args...)
    # Route tracked input through Tracker's `track`, which records the call
    # together with `f`, `M` and the extra arguments.
    return track(maprows, f, M, args...)
end
75+
76+
# Gradient rule for `maprows`: runs `Tracker.forward` once per row, keeps each
# row's backward function, and assembles the row gradients into a matrix.
# Returns the forward result and a pullback `back(Δ)` whose output has one entry
# per argument: `nothing` for `f`, the gradient for `M`, `nothing` for each extra arg.
@gradadjoint function maprows(f::Function, M::AbstractMatrix, args...)
    res = [ Tracker.forward(x -> tvec(f(x, args...)), row) for row in eachrow(data(M)) ]
    fwd = reduce(vcat, data.(first.(res)))
    function back(Δ)
        # Each slice's forward output was a 1×n row (from `tvec`), so hand its
        # pullback a 1×n sensitivity rather than a plain vector.
        # The gradient with respect to a row slice comes back as a vector; transpose it
        # so that `vcat` rebuilds a matrix of the same shape as `M`
        # (stacking the vectors directly would give one long vector).
        rows = [ transpose(data((last(res[r]))(transpose(Δrow))[1]))
                 for (r, Δrow) in enumerate(eachrow(data(Δ))) ]
        ∇M = reduce(vcat, rows)
        (nothing, ∇M, map(_->nothing, args)...)
    end
    fwd, back
end
86+
87+
4488
#========== Forward, Static ==========#
4589

4690
using TensorCast, StaticArrays, WeightedArrays
@@ -80,6 +124,7 @@ rvec(x::Number) = [x] # to allow for f vector -> scalar, as mapslices does
80124
# Static arrays are converted to an ordinary `Array` before flattening, to avoid
# creating a giant staticarray, as `reduce(hcat, ...)` would otherwise do.
function rvec(x::StaticArray)
    plain = Array(x)
    return vec(plain)
end

# Fallback: flatten anything else with `vec`.
# (The original note here reads only "LinearAlgebra.")
function rvec(A)
    return vec(A)
end
82126

127+
# Row counterpart of `rvec`: flatten `x` with `rvec`, then transpose the result.
function tvec(x)
    flat = rvec(x)
    return transpose(flat)
end
83128

84129
using ForwardDiff
85130

@@ -111,6 +156,8 @@ MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f,
111156
Z, back
112157
end
113158

159+
# TODO make a _MapCols which always takes Val(d), then unite these
160+
114161
Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d} # no dval!
115162

116163
@cast A[c]{r:d} := M[r,c]
@@ -137,4 +184,94 @@ Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d} #
137184
Z, back
138185
end
139186

187+
#========== Gradient for eachslice / reduce ==========#
188+
189+
# `mapcols7` exists only as commented-out code in this file, so it is not exported.
export gluecol, mapcols2, mapcols4, mapcols5, mapcols6
190+
191+
# Glue a vector of vectors into a matrix, one column per inner vector.
function gluecol(V::AbstractVector{<:AbstractVector})
    return reduce(hcat, V)
end
192+
193+
# A vector of tracked vectors is routed through Tracker's `track`.
function gluecol(V::AbstractVector{<:TrackedVector})
    return track(gluecol, V)
end
194+
195+
# Tracker gradient for `gluecol`: the forward pass glues the untracked data,
# the backward pass splits the incoming matrix back into its columns.
# Author's note on this rule: doesn't work
@grad function gluecol(V::AbstractVector)
    function back(ΔM)
        (collect(eachcol(data(ΔM))),)
    end
    gluecol(data.(V)), back
end
198+
199+
# Zygote adjoint for `gluecol`: the pullback splits the incoming matrix into
# a vector of its columns. Author's note on this rule: does work!
Zygote.@adjoint function gluecol(V::AbstractVector)
    function back(ΔM)
        (collect(eachcol(ΔM)),)
    end
    gluecol(V), back
end
202+
203+
# Variant of `mapcols` built from copied column slices (`A[:, j]`),
# a broadcast of `f` over them, and `gluecol` to reassemble the matrix.
function mapcols2(f, A)
    ncols = size(A, 2)
    slices = [A[:, j] for j = 1:ncols]
    gluecol(f.(slices))
end
208+
209+
# Apply that straight to reduce(hcat,...)
210+
211+
# Zygote adjoint for `reduce(hcat, V)` on a vector of vectors: no gradient for
# `hcat` itself, and the gradient for `V` is the incoming matrix split into columns.
Zygote.@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})
    function back(dV)
        (nothing, collect(eachcol(dV)),)
    end
    reduce(hcat, V), back
end
214+
215+
# Variant of `mapcols` built from column views, `map`, and `reduce(hcat, ...)`.
function mapcols4(f, A)
    ncols = size(A, 2)
    slices = [view(A, :, j) for j = 1:ncols]
    outs = map(f, slices)
    reduce(hcat, outs)
end
220+
221+
# Zygote doesn't understand views, but easy to fix:
222+
# https://github.com/FluxML/Zygote.jl/issues/52
223+
# now https://github.com/FluxML/Zygote.jl/pull/219
224+
225+
# Zygote adjoint for `view`: the pullback makes a zero array from `x` (via `_zero`),
# copies the incoming gradient into the same view of it, and returns that,
# plus `nothing` for each index argument.
Zygote.@adjoint function view(x::AbstractArray, inds...; kwargs...)
    function view_pullback(dy)
        dx = _zero(x)
        window = view(dx, inds...; kwargs...)
        copyto!(window, dy)
        (dx, map(_->nothing, inds)...)
    end
    view(x, inds...; kwargs...), view_pullback
end
232+
233+
# Surprisingly dy for eachcol seems to know the answer?
234+
# typeof(dy) = NamedTuple{(:f, :iter),Tuple{NamedTuple{(:A,),Tuple{Array{Float64,2}}},Array{Nothing,1}}}
235+
# dy = (f = (A = [47.9325 51.3781
236+
# Which means this works... but uses as much memory as gradient of array of views:
237+
238+
# Zygote adjoint for `eachcol`: the pullback simply reads the matrix gradient
# out of the incoming NamedTuple, at `dy.f.A`.
# NOTE(review): this depends on the shape of `dy` observed by the author
# (fields `f` and `iter`, with `f` holding `A`) — confirm for the Julia/Zygote versions in use.
Zygote.@adjoint function eachcol(x::AbstractMatrix)
    #= Alternative kept from the original, disabled:
        @show typeof(dy) dy
        dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
        foreach(copyto!, eachcol(dx), dy)
        (dx,)
    =#
    pick(dy) = (dy.f.A,)
    eachcol(x), pick
end
246+
247+
# @adjoint eachcol(x) = eachcol(x), dy -> (dy.f.A,)
248+
249+
# Variant of `mapcols` built from `collect(eachcol(A))`, `map`, and `reduce(hcat, ...)`.
function mapcols5(f, A)
    slices = collect(eachcol(A))
    outs = map(f, slices)
    reduce(hcat, outs)
end
254+
255+
# Materialise the columns of `x` as a vector of column slices.
# Kept as a named function so that a gradient rule can be attached to it.
function collecteachcol(x)
    return collect(eachcol(x))
end
256+
257+
# Zygote adjoint for `collecteachcol`: the pullback makes a zero array from `x`
# (via `_zero`) and copies each incoming column gradient into the matching column.
Zygote.@adjoint function collecteachcol(x)
    function back(dy)
        dx = _zero(x)
        targets = collecteachcol(dx)
        foreach(copyto!, targets, dy)
        (dx,)
    end
    collecteachcol(x), back
end
264+
265+
# Variant of `mapcols` built from `collecteachcol`, `map`, and `reduce(hcat, ...)`.
function mapcols6(f, A)
    slices = collecteachcol(A)
    outs = map(f, slices)
    reduce(hcat, outs)
end
270+
271+
# function mapcols7(f, A)
272+
# cols = eachcol(A) # without collect. Zygote.gradient -> StackOverflowError
273+
# res = map(f, cols)
274+
# reduce(hcat, res)
275+
# end
276+
140277
end # module

0 commit comments

Comments
 (0)