
Commit 47e3b52

Author: Michael Abbott
Commit message: day one
0 parents  commit 47e3b52

File tree

4 files changed: +217 −0 lines


Project.toml

18 additions, 0 deletions
```toml
name = "SliceMap"
uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8"
authors = ["Michael Abbott"]
version = "0.1.0"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
WeightedArrays = "379a43df-f81c-573e-83a6-069eb6c11a71"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
```

README.md

50 additions, 0 deletions
# SliceMap.jl

It would be nice if [Flux](https://github.com/FluxML/Flux.jl) worked with `mapslices`,
or with something generalising that. This package has some quick attempts:

```julia
mat = rand(1:99, 3,10)
fun(x) = 2 .+ x.^2
mapslices(fun, mat, dims=1)

using SliceMap

mapcols(fun, mat)    # eachcol(m)
MapCols{3}(fun, mat) # reinterpret(SArray,...)

using Tracker, Zygote, ForwardDiff

ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)

Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]    # Tracker.forward per slice
Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1] # ForwardDiff on slices

# Zygote.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]     # Zygote.forward
Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
```
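Both `mapcols` and `MapCols{3}` are meant to reproduce the `mapslices` result above exactly, which gives a quick way to sanity-check them:

```julia
mapcols(fun, mat)    == mapslices(fun, mat, dims=1)  # expected true
MapCols{3}(fun, mat) == mapslices(fun, mat, dims=1)  # expected true
```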

These are a bit faster than `mapslices` too:

```julia
using BenchmarkTools

mat1k = rand(3,1000);

@btime mapslices(fun, $mat1k, dims=1)  # 1.017 ms
@btime mapcols(fun, $mat1k)            # 399.016 μs
@btime MapCols{3}(fun, $mat1k)         #  46.733 μs
@btime MapCols(fun, $mat1k)            #  59.471 μs, without the size parameter

@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k);  # 372.705 ms
@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k);                #  70.203 ms
@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k);             # 255.032 μs
@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k);                 #  20.018 ms
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k);              # 354.112 μs
```

Of course `mapslices()` does things other than columns of matrices, most of which can be done
better with `eachslice()` and `reduce(hcat, ...)`; with some thought one could perhaps just
write gradients for those (see the sketch below).

Or write them for the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
which now does some `mapslices`-like things (and will soon do many more) by chaining such functions.
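For instance, here is a rough sketch of that `eachslice` + `reduce(hcat, ...)` pattern, applied to rows instead of columns; `maprows` is a hypothetical helper, not something this commit provides:

```julia
mat = rand(1:99, 3, 10)
fun(x) = 2 .+ x.^2

# map a vector-valued function over the rows, then glue the results back together
maprows(f, m) = permutedims(reduce(hcat, [f(r) for r in eachslice(m, dims=1)]))

maprows(fun, mat) == mapslices(fun, mat, dims=2)  # expected true
```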

src/SliceMap.jl

140 additions, 0 deletions
```julia
module SliceMap

export MapCols, mapcols

#========== Reverse, Eachslice ==========#

"""
    mapcols(f, M::Matrix, args...) = reduce(hcat, [f(col, args...) for col in eachcol(M)])

When `M::TrackedMatrix`, it saves the backward function for each slice.
"""
mapcols(f::Function, M::Matrix, args...) =
    reduce(hcat, [ rvec(f(col, args...)) for col in eachcol(M) ])

using Tracker
using Tracker: TrackedMatrix, track, @grad, data

mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)

@grad function mapcols(f::Function, M::TrackedMatrix, args...)
    res = [ Tracker.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
    fwd = reduce(hcat, data.(first.(res)))
    function back(Δ)
        cols = [ data((last(res[c]))(Δcol)[1]) for (c, Δcol) in enumerate(eachcol(data(Δ))) ]
        ∇M = reduce(hcat, cols)
        (nothing, ∇M, map(_->nothing, args)...)
    end
    fwd, back
end

using Zygote

Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
    res = [ Zygote.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
    fwd = reduce(hcat, data.(first.(res)))
    function back(Δ)
        cols = [ data((last(res[c]))(Δcol)[1]) for (c, Δcol) in enumerate(eachcol(data(Δ))) ]
        ∇M = reduce(hcat, cols)
        (nothing, ∇M, map(_->nothing, args)...)
    end
    fwd, back
end

#========== Forward, Static ==========#

using TensorCast, StaticArrays, WeightedArrays

struct MapCols{d} end

"""
    MapCols{d}(f, M::Matrix, args...)

Expects `f(::SVector{d}, args...)` and maps this over the columns, with `d = size(M,1)`.
Doesn't expect `f` to return a static array, just an ordinary array.

When `M::TrackedMatrix`, it uses `ForwardDiff` to calculate the gradient of each slice.
A second reason for keeping the one type parameter is that the dual numbers needed depend on it.

    MapCols{d}(f, M::Weighted, args...)

Takes `M.weights` along for the ride.
"""
MapCols(f::Function, M::WeightedArrays.MaybeWeightedMatrix, args...) =
    MapCols{size(M,1)}(f, M, args...)

MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
    Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)

function MapCols{d}(f::Function, M::Matrix, args...) where {d}
    @cast A[c]{r:d} := M[r,c] assert   # reinterpret the columns as SVector{d}s
    reduce(hcat, [ rvec(f(acol, args...)) for acol in A ])

    # TODO: call some function which static-glues if possible...
    # TensorCast.auto_glue(map(col -> rvec(f(col, args...)), A), (:,*))

    # TODO: can I thread this? Is it even safe to do so?
    # https://github.com/mohamed82008/KissThreading.jl
end

rvec(x::Number) = [x]                # to allow for f vector -> scalar, as mapslices does
rvec(x::StaticArray) = vec(Array(x)) # to avoid creating a giant static array, as reduce(hcat, ...) would otherwise do
rvec(A) = vec(A)                     # LinearAlgebra.

using ForwardDiff

MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f, M, Val(d), args...)

@grad function MapCols(f::Function, M::TrackedMatrix, dval::Val{d}, args...) where {d}

    @cast A[c]{r:d} := M.data[r,c]
    # one Dual per row, each seeded with a unit partial, added to every column:
    dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))

    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]

    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ]) # full is not an SVector here

    function back(ΔZ)
        ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
        @inbounds for c=1:size(M,2)
            part = ForwardDiff.partials.(C[c])
            for r=1:d
                ∇M[r,c] = 0
                for i=1:size(ΔZ,1)
                    ∇M[r,c] += data(ΔZ)[i,c] * part[i].values[r]
                end
            end
        end
        (nothing, ∇M, nothing, map(_->nothing, args)...)
    end

    Z, back
end

Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d} # no dval!

    @cast A[c]{r:d} := M[r,c]
    dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, Val(d))...), Val(d)))

    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]

    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])

    function back(ΔZ)
        ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
        @inbounds for c=1:size(M,2)
            part = ForwardDiff.partials.(C[c])
            for r=1:d
                ∇M[r,c] = 0
                for i=1:size(ΔZ,1)
                    ∇M[r,c] += data(ΔZ)[i,c] * part[i].values[r]
                end
            end
        end
        (nothing, ∇M, map(_->nothing, args)...) # changed!
    end

    Z, back
end

end # module
```
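A small usage check of these gradient rules, built only from calls shown in the README above: the Tracker rule for `MapCols` and the Zygote rule for `mapcols` should both agree with a full ForwardDiff gradient taken through `mapslices`:

```julia
using SliceMap, Tracker, Zygote, ForwardDiff

mat = rand(3, 5)
fun(x) = 2 .+ x.^2

g0 = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
g1 = Tracker.data(Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1])
g2 = Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]

g0 ≈ g1 ≈ g2   # expected to hold, up to floating-point error
```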

test/runtests.jl

9 additions, 0 deletions
```julia
using SliceMap
using Test

@testset "nothing" begin

    @test true

end
```
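The testset here is only a placeholder. A sketch of a more substantive test, my suggestion rather than anything in this commit, would simply compare the two exported functions against `mapslices`:

```julia
using SliceMap, Test

@testset "matches mapslices" begin
    mat = rand(1:99, 3, 10)
    fun(x) = 2 .+ x.^2
    @test mapcols(fun, mat) == mapslices(fun, mat, dims=1)
    @test MapCols{3}(fun, mat) == mapslices(fun, mat, dims=1)
end
```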
