Skip to content

Commit 07c8f26

Browse files
author
Michael Abbott
committed
docs
1 parent 102df32 commit 07c8f26

File tree

2 files changed

+25
-31
lines changed

2 files changed

+25
-31
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
[![Build Status](https://travis-ci.org/mcabbott/SliceMap.jl.svg?branch=master)](https://travis-ci.org/mcabbott/SliceMap.jl)
44

5-
This package provides some `mapslices`-like functions,
6-
with gradients for [Flux](https://github.com/FluxML/Flux.jl) and [Zygote](https://github.com/FluxML/Zygote.jl):
5+
This package provides some `mapslices`-like functions, with gradients defined for
6+
[Tracker](https://github.com/FluxML/Tracker.jl) and [Zygote](https://github.com/FluxML/Zygote.jl):
77

88
```julia
99
mapcols(f, M) mapreduce(f, hcat, eachcol(M))
@@ -22,7 +22,6 @@ For small slices, this will often be much faster, with or without gradients.
2222

2323
The package also defines Zygote gradients for the Slice/Align functions in
2424
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
25-
and the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
26-
both of which are good ways to roll-your-own `mapslices`-like things.
25+
which is a good ways to roll-your-own `mapslices`-like thing.
2726

2827
There are more details & examples at [docs/intro.md](docs/intro.md).

docs/intro.md

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,34 @@ mat1k = rand(3,1000);
4444
@btime Zygote.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 28.229 μs, 164.63 KiB
4545
```
4646

47-
For such a simple function, timing `sum(sin, MapCols{3}(fun, m))` takes 3 to 10 times longer!
47+
On recent versions of Julia, `mapcols` has become much faster, 5-10 times.
4848

4949
## Other packages
5050

51-
This package also provides Zygote gradients for the slice/glue functions in
51+
This package also provides Zygote gradients for the Slice/Align functions in
52+
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
53+
which can be used to write many mapslices-like operations:
54+
55+
```julia
56+
using JuliennedArrays
57+
jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
58+
jumap1(f,m) = Align(map(f, Slices(m, 1)), 1)
59+
jumap(fun, mat) # same as mapcols
60+
jumap1(fun, mat)
61+
Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]
62+
63+
@btime jumap(fun, $mat1k); # 44.823 μs
64+
@btime jumap1(fun, $mat1k); # 11.805 μs, really?
65+
@btime Zygote.gradient(m -> sum(jumap(fun, m)), $mat1k); # 26.110 ms
66+
@btime Zygote.gradient(m -> sum(jumap1(fun, m)), $mat1k) # 412.904 μs, really?
67+
```
68+
69+
It used to do the same thing for the slice/glue functions in
5270
[TensorCast](https://github.com/mcabbott/TensorCast.jl),
53-
which can be used to write many mapslices-like operations.
54-
(The function `slicemap(f, A, dims)` uses these functions, without having to write index notation.)
71+
but but that should soon be part of that package:
5572

5673
```julia
57-
using TensorCast
74+
using TensorCast#two
5875
@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
5976

6077
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
@@ -64,22 +81,6 @@ Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
6481
@btime Zygote.gradient(m -> sum(tcm(m)), $mat1k); # 18.358 ms
6582
```
6683

67-
Similar gradients work for the Slice/Align functions in
68-
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
69-
so it defines these too:
70-
71-
```julia
72-
using JuliennedArrays
73-
jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
74-
jumap(fun, mat) # same as mapcols
75-
Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]
76-
77-
@btime jumap(fun, $mat1k); # 421.061 μs
78-
@btime Zygote.gradient(m -> sum(jumap(fun, m)), $mat1k); # 18.383 ms
79-
```
80-
81-
That's a 2-line gradient definition, so borrowing it may be easier than depending on this package.
82-
8384
The original purpose of `MapCols`, with ForwardDiff on slices, was that this works well when
8485
the function being mapped integrates some differential equation.
8586

@@ -128,9 +129,3 @@ Other packages which define gradients of possible interest:
128129
* https://github.com/GiggleLiu/LinalgBackwards.jl
129130
* https://github.com/mcabbott/ArrayAllez.jl
130131

131-
Differentiation packages this could perhaps support, quite the zoo:
132-
* https://github.com/dfdx/Yota.jl
133-
* https://github.com/invenia/Nabla.jl
134-
* https://github.com/denizyuret/AutoGrad.jl
135-
* https://github.com/Roger-luo/YAAD.jl
136-
* And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl

0 commit comments

Comments
 (0)