@@ -26,23 +26,23 @@ Any arguments after the matrix are passed to `f` as scalars, i.e.
26
26
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eeachcol(m))`.
27
27
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
28
28
"""
29
- mapcols (f:: Function , M, args... ) = _mapcols (map, f, M, args... )
30
- tmapcols (f:: Function , M, args... ) = _mapcols (threadmap, f, M, args... )
29
+ mapcols (f, M, args... ) = _mapcols (map, f, M, args... )
30
+ tmapcols (f, M, args... ) = _mapcols (threadmap, f, M, args... )
31
31
32
- function _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... )
32
+ function _mapcols (map:: Function , f, M:: AbstractMatrix , args... )
33
33
res = map (col -> _vec (f (col, args... )), eachcol (M))
34
34
eltype (res) <: AbstractVector ? reduce (hcat, res) : reshape (res,1 ,:)
35
35
end
36
36
37
37
_vec (x) = x
38
38
_vec (A:: AbstractArray ) = vec (A) # to allow f vector -> matrix, by reshaping
39
39
40
- _mapcols (map:: Function , f:: Function , M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
40
+ _mapcols (map:: Function , f, M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
41
41
42
- @grad _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
42
+ @grad _mapcols (map:: Function , f, M:: AbstractMatrix , args... ) =
43
43
∇mapcols (map, map (col -> Tracker. forward (x -> _vec (f (x, args... )), col), eachcol (data (M))), args... )
44
44
45
- @adjoint _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
45
+ @adjoint _mapcols (map:: Function , f, M:: AbstractMatrix , args... ) =
46
46
∇mapcols (map, map (col -> ZygoteRules. pullback (x -> _vec (f (x, args... )), col), eachcol (M)), args)
47
47
48
48
function ∇mapcols (bigmap, forwards, args... )
@@ -91,7 +91,7 @@ e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
91
91
92
92
The gradient is for Zygote only.
93
93
"""
94
- function slicemap (f:: Function , A:: AbstractArray{T,N} , args... ; dims) where {T,N}
94
+ function slicemap (f, A:: AbstractArray{T,N} , args... ; dims) where {T,N}
95
95
code = ntuple (d -> d in dims ? True () : False (), N)
96
96
B = JuliennedArrays. Slices (A, code... )
97
97
C = [ f (slice, args... ) for slice in B ]
0 commit comments