Commit 5879275

bugs, tests
Michael Abbott committed
1 parent abce623 commit 5879275

4 files changed, with 139 additions and 13 deletions

.travis.yml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+language: julia
+os:
+  - linux
+  - osx
+julia:
+  - 1.0
+  - 1.1
+  - nightly
+
+matrix:
+  allow_failures:
+    - julia: nightly

README.md

Lines changed: 23 additions & 3 deletions
@@ -1,6 +1,6 @@
 # SliceMap.jl
 
-It would be nice if [Flux](https://github.com/FluxML/Flux.jl) worked with `mapslices`,
+It would be nice if [Flux](https://github.com/FluxML/Flux.jl) / [Zygote](https://github.com/FluxML/Zygote.jl) worked with `mapslices`,
 or with something generalising that. This package has some quick attempts:
 
 ```julia
@@ -23,7 +23,8 @@ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Zygote.forward
 Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
 ```
 
-These are a bit faster than `mapslices` too:
+These are a bit faster than `mapslices` too. Although storing all the backward functions,
+which is what `mapcols` does, seems not to be so quick:
 
 ```julia
 using BenchmarkTools
@@ -47,7 +48,7 @@ Of course `mapslices()` does things other than columns of matrices.
 Most of which can be done better with `eachslice()` and `reduce(hcat,...)`,
 maybe with some thought one could just write gradients for those...
 
-Perhaps this is done. The views of `eachcol()` have quite inefficient gradients,
+Perhaps this is done, at least for Zygote. The views of `eachcol()` have quite inefficient gradients,
 because for each `view()` they make a fresh `zero(A)`, but `collecteachcol()` is efficient:
 
 ```julia
@@ -73,6 +74,25 @@ ten = rand(1:9, 3,10,2)
 Zygote.gradient(m -> sum(sin, @cast zed[i,j,k] := fun(m[i,:,k])[j] nolazy), ten)[1]
 ```
 
+The function `slicemap(f, A, dims)` uses these slice/glue functions,
+without having to write index notation.
+
 Issues about mapslices:
 * https://github.com/FluxML/Zygote.jl/issues/92
 * https://github.com/FluxML/Flux.jl/issues/741
+
+Other packages which define gradients of possible interest:
+* https://github.com/GiggleLiu/LinalgBackwards.jl
+* https://github.com/mcabbott/ArrayAllez.jl
+
+I added some tests:
+[![Build Status](https://travis-ci.org/mcabbott/SliceMap.jl.svg?branch=master)](https://travis-ci.org/mcabbott/SliceMap.jl)
+
+<!--
+AD packages this could perhaps support, quite the zoo:
+* https://github.com/invenia/Nabla.jl
+* https://github.com/dfdx/Yota.jl
+* https://github.com/denizyuret/AutoGrad.jl
+* https://github.com/Roger-luo/YAAD.jl
+* And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl
+-->

src/SliceMap.jl

Lines changed: 9 additions & 8 deletions
@@ -51,22 +51,20 @@ end
 Like `mapcols()` but for rows.
 """
 maprows(f::Function, M::AbstractMatrix, args...) =
-    reduce(vcat, [ surerow(f(col, args...)) for col in eachrow(M) ])
-
-surerow(x) = transpose(surevec(x))
+    reduce(vcat, [ transpose(surevec(f(col, args...))) for col in eachrow(M) ])
 
 maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
 
 @grad maprows(f::Function, M::AbstractMatrix, args...) =
-    ∇maprows([ Tracker.forward(x -> surerow(f(x, args...)), row) for row in eachrow(data(M)) ], args)
+    ∇maprows([ Tracker.forward(x -> surevec(f(x, args...)), row) for row in eachrow(data(M)) ], args)
 
 @adjoint maprows(f::Function, M::AbstractMatrix, args...) =
-    ∇maprows([ Zygote.forward(x -> surerow(f(x, args...)), row) for row in eachrow(M) ], args)
+    ∇maprows([ Zygote.forward(x -> surevec(f(x, args...)), row) for row in eachrow(M) ], args)
 
 function ∇maprows(forwards, args)
-    reduce(vcat, data.(first.(forwards))), Δ -> begin
+    reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
         rows = [ data(last(fwd)(Δrow)[1]) for (fwd, Δrow) in zip(forwards, eachrow(data(Δ))) ]
-        (nothing, reduce(vcat, rows), map(_->nothing, args)...)
+        (nothing, reduce(vcat, transpose.(rows)), map(_->nothing, args)...)
     end
 end
 
@@ -265,7 +263,10 @@ end
 """
     slicemap(f, A; dims) ≈ mapslices(f, A; dims)
 
-Like `mapcols()`, but for any slice. Gradient is for Zygote only.
+Like `mapcols()`, but for any slice. The function `f` must preserve shape,
+e.g. `dims=(2,4)` then `f` must map matrices to matrices.
+
+The gradient is for Zygote only.
 """
 function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
     code = ntuple(d -> d in dims ? (:) : (*), N)

test/runtests.jl

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,102 @@
11

22
using SliceMap
33
using Test
4+
using ForwardDiff, Tracker, Zygote, TensorCast
45

5-
@testset "nothing" begin
6+
Zygote.refresh()
67

7-
@test true
8+
@testset "columns" begin
9+
10+
mat = rand(1:9, 3,10)
11+
fun(x) = 2 .+ x.^2
12+
res = mapslices(fun, mat, dims=1)
13+
14+
@test res mapcols(fun, mat)
15+
@test res MapCols{3}(fun, mat)
16+
17+
grad = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
18+
19+
@test grad Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
20+
@test grad Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
21+
22+
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
23+
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
24+
25+
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
26+
@test res tcm(mat)
27+
@test grad Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
28+
29+
end
+@testset "columns, scalar" begin
+
+    mat = rand(1:9, 3,10)
+    fun(x) = sum(x) # different function!
+    res = mapslices(fun, mat, dims=1)
+
+    @test res ≈ mapcols(fun, mat)
+    @test res ≈ MapCols{3}(fun, mat)
+
+    grad = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
+
+    @test grad ≈ Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
+    @test grad ≈ Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
+
+    @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
+    @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
+
+    tcm3(mat) = @cast out[_,j] := fun(mat[:,j]) # changed here too
+    @test res ≈ tcm3(mat)
+    @test grad ≈ Zygote.gradient(m -> sum(sin, tcm3(m)), mat)[1]
+
+end
+@testset "columns, matrix" begin
+
+    mat = rand(1:9, 3,10)
+    fun(x) = x .* x' # different function! vector -> matrix
+    res = mapslices(vecfun, mat, dims=1) # NOTE: vecfun is not defined in this diff; presumably vecfun(x) = vec(fun(x))
+
+    @test res ≈ mapcols(fun, mat)
+    @test res ≈ MapCols{3}(fun, mat)
+
+    grad = ForwardDiff.gradient(m -> sum(sin, mapslices(vecfun, m, dims=1)), mat)
+
+    @test grad ≈ Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
+    @test grad ≈ Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
+
+    @test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
+    @test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
+
+    tcm4(mat) = @cast out[ii′,j] := fun(mat[:,j])[i,i′] i:3, i′:3 # changed here too
+    @test res ≈ tcm4(mat)
+    @test grad ≈ Zygote.gradient(m -> sum(sin, tcm4(m)), mat)[1]
+
+end
+@testset "rows" begin
+
+    mat = randn(4,5)
+    fun(x) = 2 .+ x.^2 ./ sum(x)
+
+    res = mapslices(fun, mat, dims=2)
+    @test res ≈ maprows(fun, mat)
+
+    grad = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=2)), mat)
+    @test grad ≈ Tracker.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
+    @test grad ≈ Zygote.gradient(m -> sum(sin, maprows(fun, m)), mat)[1]
+
+    tcm2(mat) = @cast out[i,j] := fun(mat[i,:])[j]
+    @test res ≈ tcm2(mat)
+    @test grad ≈ Zygote.gradient(m -> sum(sin, tcm2(m)), mat)[1]
+
+end
+@testset "slices" begin
+
+    ten = randn(3,4,5,2)
+    fun(x) = sqrt(3) .+ x.^3 ./ (sum(x)^2)
+    res = mapslices(fun, ten, dims=3)
+
+    @test res ≈ slicemap(fun, ten, dims=3)
+
+    grad = ForwardDiff.gradient(x -> sum(sin, slicemap(fun, x, dims=3)), ten)
+    @test grad ≈ Zygote.gradient(x -> sum(sin, slicemap(fun, x, dims=3)), ten)[1]
 
 end
