@@ -25,24 +25,26 @@ It provides a gradient for Tracker and Zygote, saving the backward function for
25
25
Any arguments after the matrix are passed to `f` as scalars, i.e.
26
26
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eeachcol(m))`.
27
27
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
28
+
29
+ Note that if `f` itself contains parameters, their gradients are also not tracked.
28
30
"""
29
- mapcols (f:: Function , M, args... ) = _mapcols (map, f, M, args... )
30
- tmapcols (f:: Function , M, args... ) = _mapcols (threadmap, f, M, args... )
31
+ mapcols (f, M, args... ) = _mapcols (map, f, M, args... )
32
+ tmapcols (f, M, args... ) = _mapcols (threadmap, f, M, args... )
31
33
32
- function _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... )
34
+ function _mapcols (map:: Function , f, M:: AbstractMatrix , args... )
33
35
res = map (col -> _vec (f (col, args... )), eachcol (M))
34
36
eltype (res) <: AbstractVector ? reduce (hcat, res) : reshape (res,1 ,:)
35
37
end
36
38
37
39
_vec (x) = x
38
40
_vec (A:: AbstractArray ) = vec (A) # to allow f vector -> matrix, by reshaping
39
41
40
- _mapcols (map:: Function , f:: Function , M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
42
+ _mapcols (map:: Function , f, M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
41
43
42
- @grad _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
44
+ @grad _mapcols (map:: Function , f, M:: AbstractMatrix , args... ) =
43
45
∇mapcols (map, map (col -> Tracker. forward (x -> _vec (f (x, args... )), col), eachcol (data (M))), args... )
44
46
45
- @adjoint _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
47
+ @adjoint _mapcols (map:: Function , f, M:: AbstractMatrix , args... ) =
46
48
∇mapcols (map, map (col -> ZygoteRules. pullback (x -> _vec (f (x, args... )), col), eachcol (M)), args)
47
49
48
50
function ∇mapcols (bigmap, forwards, args... )
@@ -90,8 +92,11 @@ Like `mapcols()`, but for any slice. The function `f` must preserve shape,
90
92
e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
91
93
92
94
The gradient is for Zygote only.
95
+
96
+ Parameters within the function `f` (if there are any) should be correctly tracked,
97
+ which is not the case for `mapcols()`.
93
98
"""
94
- function slicemap (f:: Function , A:: AbstractArray{T,N} , args... ; dims) where {T,N}
99
+ function slicemap (f, A:: AbstractArray{T,N} , args... ; dims) where {T,N}
95
100
code = ntuple (d -> d in dims ? True () : False (), N)
96
101
B = JuliennedArrays. Slices (A, code... )
97
102
C = [ f (slice, args... ) for slice in B ]
0 commit comments