Skip to content

Commit b9dd221

Browse files
author
Michael Abbott
committed
readme
1 parent 754d99b commit b9dd221

File tree

1 file changed

+21
-29
lines changed

1 file changed

+21
-29
lines changed

README.md

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
# SliceMap.jl
22

3-
It would be nice if [Flux](https://github.com/FluxML/Flux.jl) / [Zygote](https://github.com/FluxML/Zygote.jl) worked with `mapslices`,
4-
or with something generalising that. This package has some quick attempts:
3+
[![Build Status](https://travis-ci.org/mcabbott/SliceMap.jl.svg?branch=master)](https://travis-ci.org/mcabbott/SliceMap.jl)
4+
5+
This package provides some `mapslices`-like functions,
6+
with gradients for [Flux](https://github.com/FluxML/Flux.jl) and [Zygote](https://github.com/FluxML/Zygote.jl):
7+
8+
```julia
9+
mapcols(f, M) mapreduce(f, hcat, eachcol(M))
10+
MapCols{d}(f, M) # where d=size(M,1), for StaticArrays
11+
12+
maprows(f, M) mapreduce(f, vcat, eachrow(M))
13+
14+
slicemap(f, A; dims) mapslices(f, A, dims)
15+
```
16+
17+
### An example
518

619
```julia
720
mat = rand(1:9, 3,10)
@@ -18,8 +31,7 @@ ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
1831
Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Tracker.forward per slice
1932
Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1] # ForwardDiff on slices
2033

21-
# Zygote.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat) # errors
22-
Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Zygote.forward
34+
Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Zygote.forward per slice
2335
Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
2436
```
2537

@@ -44,20 +56,10 @@ mat1k = rand(3,1000);
4456
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs
4557
```
4658

47-
Of course `mapslices()` does things other than columns of matrices.
48-
Most of which can be done better with `eachslice()` and `reduce(hcat,...)`,
49-
maybe with some thought one could just write gradients for those...
50-
51-
Perhaps this is done, at least for Zygote. The views of `eachcol()` have quite inefficient gradients,
52-
because for each `view()` they make a fresh `zero(A)`, but `collecteachcol()` is efficient:
53-
54-
```julia
55-
@btime Zygote.gradient(m -> sum(sin, mapcols4(fun, m)), $mat1k); # 45.616 ms, 49.49 MiB
56-
@btime Zygote.gradient(m -> sum(sin, mapcols6(fun, m)), $mat1k); # 18.655 ms, 3.37 MiB
57-
```
58-
59-
Or for the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
60-
which now does some mapslices things (and will soon do many more) by chaining such functions.
59+
It also provides Zygote gradients for the slice/glue functions in
60+
[TensorCast](https://github.com/mcabbott/TensorCast.jl),
61+
which can be used to write many mapslices-like operations.
62+
(The function `slicemap(f, A, dims)` uses these functions, without having to write index notation.)
6163

6264
```julia
6365
using TensorCast
@@ -68,14 +70,9 @@ Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
6870

6971
@btime tcm($mat1k) # 407.176 μs
7072
@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k) # 19.086 ms
71-
72-
ten = rand(1:9, 3,10,2)
73-
@cast zed[i,j,k] := fun(ten[i,:,k])[j]
74-
Zygote.gradient(m -> sum(sin, @cast zed[i,j,k] := fun(m[i,:,k])[j] nolazy), ten)[1]
7573
```
7674

77-
The function `slicemap(f, A, dims)` uses these slice/glue functions,
78-
without having to write index notation.
75+
### Elsewhere
7976

8077
Issues about mapslices:
8178
* https://github.com/FluxML/Zygote.jl/issues/92
@@ -85,14 +82,9 @@ Other packages which define gradients of possible interest:
8582
* https://github.com/GiggleLiu/LinalgBackwards.jl
8683
* https://github.com/mcabbott/ArrayAllez.jl
8784

88-
I added some tests:
89-
[![Build Status](https://travis-ci.org/mcabbott/SliceMap.jl.svg?branch=master)](https://travis-ci.org/mcabbott/SliceMap.jl)
90-
91-
<!--
9285
AD packages this could perhaps support, quite the zoo:
9386
* https://github.com/invenia/Nabla.jl
9487
* https://github.com/dfdx/Yota.jl
9588
* https://github.com/denizyuret/AutoGrad.jl
9689
* https://github.com/Roger-luo/YAAD.jl
9790
* And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl
98-
-->

0 commit comments

Comments
 (0)