# SliceMap.jl

[![Build Status](https://travis-ci.org/mcabbott/SliceMap.jl.svg?branch=master)](https://travis-ci.org/mcabbott/SliceMap.jl)

This package provides some `mapslices`-like functions,
with gradients for [Flux](https://github.com/FluxML/Flux.jl) and [Zygote](https://github.com/FluxML/Zygote.jl):

```julia
mapcols(f, M) ≈ mapreduce(f, hcat, eachcol(M))
MapCols{d}(f, M)  # where d = size(M,1), for StaticArrays

maprows(f, M) ≈ mapreduce(f, vcat, eachrow(M))

slicemap(f, A; dims) ≈ mapslices(f, A; dims=dims)
```
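The `{d}` in `MapCols{d}` is the slice length, supplied as a compile-time constant so that each column can be handled as a static array. A minimal sketch of that idea (not the package's actual code), assuming StaticArrays:

```julia
using StaticArrays

M = rand(3, 5)
# With the length 3 known at compile time, the columns can be viewed as
# SVector{3}s without copying:
cols = reinterpret(SVector{3,Float64}, vec(M))
reduce(hcat, [2 .+ c.^2 for c in cols])  # ≈ mapcols(x -> 2 .+ x.^2, M)
```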

### An example

```julia
mat = rand(1:9, 3, 10)
# ...
ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)

Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]     # Tracker.forward per slice
Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]  # ForwardDiff on slices

Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]      # Zygote.forward per slice
Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
```
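The comments above hint at how these gradients work: the function is differentiated once per slice, and the per-slice gradients are glued back together. A rough illustration of that pattern for columns (not the package's actual implementation), written with the old `Zygote.forward` API named in the comments, which later Zygote versions call `Zygote.pullback`:

```julia
using Zygote

# Apply f to each column, keeping that column's pullback; then push each
# column of the output sensitivity Δ back through its own pullback.
function mapcols_sketch(f, M, Δ)
    fwd = [Zygote.forward(f, col) for col in eachcol(M)]  # (value, pullback) per column
    y   = reduce(hcat, [v for (v, _) in fwd])             # the forward result
    dM  = reduce(hcat, [fwd[j][2](Δ[:, j])[1] for j in axes(M, 2)])
    return y, dM
end

y, dM = mapcols_sketch(x -> 2 .+ x.^2, rand(3, 4), ones(3, 4))
```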

```julia
mat1k = rand(3, 1000);
# ...
@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k);  # 245.550 μs
```

It also provides Zygote gradients for the slice/glue functions in
[TensorCast](https://github.com/mcabbott/TensorCast.jl),
which can be used to write many mapslices-like operations.
(The function `slicemap(f, A; dims)` uses these functions, without having to write index notation.)

```julia
using TensorCast
# ...
Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]

@btime tcm($mat1k)                                      # 407.176 μs
@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k)   # 19.086 ms
```
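For instance, a slice-wise function applied along the first dimension of a 3-tensor; the `fun` and `ten` here are made up for illustration (any vector-to-vector function works the same way):

```julia
using SliceMap, Zygote

fun(x) = 2 .+ x.^2   # hypothetical slice-wise function
ten = rand(3, 10, 2)

slicemap(fun, ten; dims=1)  # like mapslices(fun, ten, dims=1), but differentiable
Zygote.gradient(t -> sum(sin, slicemap(fun, t; dims=1)), ten)[1]
```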

### Elsewhere

Issues about mapslices:
* https://github.com/FluxML/Zygote.jl/issues/92

Other packages which define gradients of possible interest:
* https://github.com/GiggleLiu/LinalgBackwards.jl
* https://github.com/mcabbott/ArrayAllez.jl

AD packages this could perhaps support, quite the zoo:
* https://github.com/invenia/Nabla.jl
* https://github.com/dfdx/Yota.jl
* https://github.com/denizyuret/AutoGrad.jl
* https://github.com/Roger-luo/YAAD.jl
* And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl