@@ -15,140 +15,14 @@ maprows(f, M) ≈ mapslices(f, M, dims=2)
slicemap(f, A; dims) ≈ mapslices(f, A, dims=dims)  # only Zygote
```

- <!--
- It also defines Zygote gradients for the Slice/Align functions in
+ The capitalised functions differ both in using [StaticArrays](https://github.com/JuliaArrays/StaticArrays.jl)
+ slices, and in using [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for the gradient of each slice,
+ instead of the surrounding reverse-mode Tracker/Zygote.
+ For small slices, this will often be much faster, with or without gradients.
+
+ The package also defines Zygote gradients for the Slice/Align functions in
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
and the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
- both of which are good ways to roll-your-own `mapslices`-like behaviour.
- -->
-
- ### Simple example
-
- ``` julia
28
- mat = rand (1 : 9 , 3 ,10 )
29
- fun (x) = 2 .+ x.^ 2
30
- mapslices (fun, mat, dims= 1 )
31
-
32
- using SliceMap
33
- mapcols (fun, mat) # eachcol(m)
34
- MapCols {3} (fun, mat) # reinterpret(SArray,...)
35
-
36
- using ForwardDiff, Tracker, Zygote
37
- ForwardDiff. gradient (m -> sum (sin, mapslices (fun, m, dims= 1 )), mat)
38
-
39
- Tracker. gradient (m -> sum (sin, mapcols (fun, m)), mat)[1 ] # Tracker.forward per slice
40
- Tracker. gradient (m -> sum (sin, MapCols {3} (fun, m)), mat)[1 ] # ForwardDiff on slices
41
-
42
- Zygote. gradient (m -> sum (sin, mapcols (fun, m)), mat)[1 ] # Zygote.forward per slice
43
- Zygote. gradient (m -> sum (sin, MapCols {3} (fun, m)), mat)[1 ]
44
- ```
-
- These are a bit faster than `mapslices` too, although storing all the backward functions,
- which is what `mapcols` does, seems not to be so quick:
-
- ```julia
- using BenchmarkTools
- mat1k = rand(3, 1000);
-
- @btime mapreduce(fun, hcat, eachcol($mat1k))  # 1.522 ms
- @btime mapslices(fun, $mat1k, dims=1)         # 1.017 ms
-
- @btime mapcols(fun, $mat1k)                   # 399.016 μs
- @btime MapCols{3}(fun, $mat1k)                #  15.564 μs
- @btime MapCols(fun, $mat1k)                   #  16.774 μs, without the size
-
- @btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k);  # 372.705 ms
- @btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k);      #  70.203 ms
- @btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k);   # 146.561 μs, 330.51 KiB
- @btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k);       #  20.018 ms, 3.82 MiB
- @btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k);    # 245.550 μs
- ```
-
- ### Other packages
-
- This package also provides Zygote gradients for the slice/glue functions in
- [TensorCast](https://github.com/mcabbott/TensorCast.jl),
- which can be used to write many `mapslices`-like operations.
- (The function `slicemap(f, A; dims)` uses these functions, without having to write index notation.)
-
- ```julia
- using TensorCast
- @cast [i,j] := fun(mat[:,j])[i]    # same as mapcols
-
- tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
- Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
-
- @btime tcm($mat1k)                                       # 407.176 μs
- @btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k);   #  19.086 ms
- ```
-
- Similar gradients work for the Slice/Align functions in
- [JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
- so this package defines those too:
-
- ```julia
- using JuliennedArrays
- jumap(f, m) = Align(map(f, Slices(m, True(), False())), True(), False())
- jumap(fun, mat)    # same as mapcols
- Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]
-
- @btime jumap(fun, $mat1k);                                       # 408.259 μs
- @btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k);    #  18.638 ms
- ```
-
- That's a 2-line gradient definition, so borrowing it may be easier than depending on this package.
-
- The original motivation for `MapCols`, with ForwardDiff on each slice, was that it works well when
- the function being mapped integrates a differential equation.
-
- ```julia
- using DifferentialEquations, ParameterizedFunctions
- ode = @ode_def begin
-     du = (-k2 * u) / (k1 + u)    # an equation with 2 parameters
- end k1 k2
-
- function g(k::AbstractVector{T}, times) where T
-     u0 = T[1.0]                  # NB: convert initial values to eltype(k)
-     prob = ODEProblem(ode, u0, (0.0, 0.0 + maximum(times)), k)
-     Array(solve(prob, saveat=times))::Matrix{T}
- end
-
- kay = rand(2, 50);
- MapCols{2}(g, kay, 1:5)          # 5 time steps, for each column of parameters
-
- Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), kay)[1]
- ```
-
- This is quite efficient, and seems to go well with multi-threading:
-
- ```julia
- @btime MapCols{2}(g, $kay, 1:5)          #   1.369 ms
- @btime ThreadMapCols{2}(g, $kay, 1:5)    # 670.384 μs
-
- @btime Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), $kay)[1]          # 2.438 ms
- @btime Tracker.gradient(k -> sum(sin, ThreadMapCols{2}(g, k, 1:5)), $kay)[1]    # 1.229 ms
-
- Threads.nthreads() == 4    # timings above were run with 4 threads
- ```
-
- ### Elsewhere
-
- Issues about `mapslices`:
- * https://github.com/FluxML/Zygote.jl/issues/92
- * https://github.com/FluxML/Flux.jl/issues/741
- * https://github.com/JuliaLang/julia/issues/29146
-
- Differential equations:
- * https://arxiv.org/abs/1812.01892 "DSAAD"
- * http://docs.juliadiffeq.org/latest/analysis/sensitivity.html
-
- Other packages which define gradients of possible interest:
- * https://github.com/GiggleLiu/LinalgBackwards.jl
- * https://github.com/mcabbott/ArrayAllez.jl
+ both of which are good ways to roll-your-own `mapslices`-like things.

- Differentiation packages this could perhaps support, quite the zoo:
- * https://github.com/dfdx/Yota.jl
- * https://github.com/invenia/Nabla.jl
- * https://github.com/denizyuret/AutoGrad.jl
- * https://github.com/Roger-luo/YAAD.jl
- * And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl
+ There are more details & examples at [docs/intro.md](docs/intro.md).
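
For orientation, the StaticArrays point in the added text can be sketched by hand. This is only an illustration of the idea behind `MapCols` (the old `# reinterpret(SArray,...)` comment), not the package's actual implementation; `mat` and `fun` are re-declared from the removed example so the snippet is self-contained:

```julia
using StaticArrays

mat = rand(3, 10)        # same shape as the removed example
fun(x) = 2 .+ x.^2

# Reinterpret the 3×10 matrix as ten SVector{3} column slices, apply `fun`
# to each static slice, then glue the results back into a matrix.
cols = reinterpret(SVector{3, eltype(mat)}, vec(mat))
out  = reduce(hcat, map(fun, cols))

out ≈ mapslices(fun, mat, dims=1)   # true
```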
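
The removed text also describes the JuliennedArrays support as "a 2-line gradient definition". A sketch of what such a pair of Zygote adjoints could look like, assuming the `True()`/`False()` slicing API used in the removed example and glossing over `nothing` cotangents, is:

```julia
using Zygote, JuliennedArrays

# Slicing and gluing are inverses, so each pullback simply applies the other;
# the trailing `nothing`s are cotangents for the True()/False() markers.
Zygote.@adjoint function JuliennedArrays.Slices(whole, alongs...)
    Slices(whole, alongs...), dy -> (Align(dy, alongs...), map(_ -> nothing, alongs)...)
end

Zygote.@adjoint function JuliennedArrays.Align(slices, alongs...)
    Align(slices, alongs...), dy -> (Slices(dy, alongs...), map(_ -> nothing, alongs)...)
end
```

With definitions of this shape in place, the `jumap(fun, mat)` example from the removed block should differentiate under Zygote without further help.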