@@ -18,6 +18,7 @@ using ZygoteRules: pullback, @adjoint
18
18
19
19
This is a more efficient version of the functions on the right.
20
20
For `f(x::Vector)::Matrix` it reshapes like `mapslices(vec∘f, m, dims=1)`.
21
+ For `f(x::Vector)::Number` it skips the reduction and simply returns `reshape(map(f, eachcol(m)),1,:)`.
21
22
22
23
It provides a gradient for Tracker and Zygote, saving the backward function for each slice.
23
24
@@ -28,53 +29,58 @@ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
28
29
mapcols (f:: Function , M, args... ) = _mapcols (map, f, M, args... )
29
30
tmapcols (f:: Function , M, args... ) = _mapcols (threadmap, f, M, args... )
30
31
31
- _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
32
- reduce (hcat, map (col -> surevec (f (col, args... ), M), eachcol (M)))
32
+ function _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... )
33
+ res = map (col -> _vec (f (col, args... )), eachcol (M))
34
+ eltype (res) <: AbstractVector ? reduce (hcat, res) : reshape (res,1 ,:)
35
+ end
36
+
37
+ _vec (x) = x
38
+ _vec (A:: AbstractArray ) = vec (A) # to allow f vector -> matrix, by reshaping
33
39
34
40
_mapcols (map:: Function , f:: Function , M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
35
41
36
42
@grad _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
37
- ∇mapcols (map, map (col -> Tracker. forward (x -> surevec (f (x, args... ), M ), col), eachcol (data (M))), args... )
43
+ ∇mapcols (map, map (col -> Tracker. forward (x -> _vec (f (x, args... )), col), eachcol (data (M))), args... )
38
44
39
45
@adjoint _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
40
- ∇mapcols (map, map (col -> ZygoteRules. pullback (x -> surevec (f (x, args... ), M ), col), eachcol (M)), args)
46
+ ∇mapcols (map, map (col -> ZygoteRules. pullback (x -> _vec (f (x, args... )), col), eachcol (M)), args)
41
47
42
48
function ∇mapcols (bigmap, forwards, args... )
43
- reduce (hcat, map (data∘ first, forwards)), Δ -> begin
44
- cols = bigmap ((fwd, Δcol) -> data (last (fwd)(Δcol)[1 ]), forwards, eachcol (data (Δ)))
49
+ res = map (data∘ first, forwards)
50
+ function back (Δ)
51
+ Δcols = eltype (res) <: AbstractVector ? eachcol (data (Δ)) : vec (data (Δ))
52
+ cols = bigmap ((fwd, Δcol) -> data (last (fwd)(Δcol)[1 ]), forwards, Δcols)
45
53
(nothing , nothing , reduce (hcat, cols), map (_-> nothing , args)... )
46
54
end
55
+ eltype (res) <: AbstractVector ? reduce (hcat, res) : reshape (res,1 ,:), back
47
56
end
48
57
49
- surevec (A:: AbstractArray , M) = vec (A) # to allow f vector -> matrix, by reshaping
50
- surevec (x:: Number , M) = _veclike (x, M) # to allow f vector -> scalar, as mapslices does
51
-
52
- _veclike (x:: T , M) where {T} = fill! (similar (M, T, 1 ), x) # use similar to preserve CuArrays, #4
53
- _veclike (x:: TrackedReal , M) = track (_veclike, x, M)
54
- @grad _veclike (x, M) = _veclike (data (x), M), Δ -> (first (Δ), nothing )
55
- @adjoint _veclike (x, M) = _veclike (x, M), Δ -> (first (Δ), nothing )
56
-
57
58
"""
58
59
maprows(f, M) ≈ mapslices(f, M, dims=2)
59
60
60
61
Like `mapcols()` but for rows.
61
62
"""
62
- maprows (f:: Function , M:: AbstractMatrix , args... ) =
63
- reduce (vcat, map (col -> transpose (surevec (f (col, args... ), M)), eachrow (M)))
63
+ function maprows (f:: Function , M:: AbstractMatrix , args... )
64
+ res = map (col -> transpose (_vec (f (col, args... ))), eachrow (M))
65
+ eltype (res) <: AbstractArray ? reduce (vcat, res) : reshape (res,:,1 )
66
+ end
64
67
65
68
maprows (f:: Function , M:: TrackedMatrix , args... ) = track (maprows, f, M, args... )
66
69
67
70
@grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
68
- ∇maprows (map (row -> Tracker. forward (x -> surevec (f (x, args... ), M ), row), eachrow (data (M))), args)
71
+ ∇maprows (map (row -> Tracker. forward (x -> _vec (f (x, args... )), row), eachrow (data (M))), args)
69
72
70
73
@adjoint maprows (f:: Function , M:: AbstractMatrix , args... ) =
71
- ∇maprows (map (row -> ZygoteRules. pullback (x -> surevec (f (x, args... ), M ), row), eachrow (M)), args)
74
+ ∇maprows (map (row -> ZygoteRules. pullback (x -> _vec (f (x, args... )), row), eachrow (M)), args)
72
75
73
76
function ∇maprows (forwards, args)
74
- reduce (vcat, map (transpose∘ data∘ first, forwards)), Δ -> begin
75
- rows = map ((fwd, Δrow) -> data (last (fwd)(Δrow)[1 ]), forwards, eachrow (data (Δ)))
76
- (nothing , reduce (vcat, transpose .(rows)), map (_-> nothing , args)... )
77
+ res = map (transpose∘ data∘ first, forwards)
78
+ function back (Δ)
79
+ Δrows = eltype (res) <: AbstractArray ? eachrow (data (Δ)) : vec (data (Δ))
80
+ rows = map ((fwd, Δrow) -> transpose (data (last (fwd)(Δrow)[1 ])), forwards, Δrows)
81
+ (nothing , reduce (vcat, rows), map (_-> nothing , args)... )
77
82
end
83
+ eltype (res) <: AbstractArray ? reduce (vcat, res) : reshape (res,:,1 ), back
78
84
end
79
85
80
86
"""
@@ -115,10 +121,13 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
115
121
function _MapCols (map:: Function , f:: Function , M:: Matrix{T} , :: Val{d} , args... ) where {T,d}
116
122
d == size (M,1 ) || error (" expected M with $d rows" )
117
123
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
118
- B = map (col -> surevec (f (col, args... ), M ), A)
124
+ B = map (col -> surevec (f (col, args... )), A)
119
125
reduce (hcat, B)
120
126
end
121
127
128
+ surevec (A:: AbstractArray ) = vec (A)
129
+ surevec (x:: Number ) = SVector (x) # simple way to deal with f vector -> scalar
130
+
122
131
_MapCols (map:: Function , f:: Function , M:: TrackedMatrix , dval, args... ) =
123
132
track (_MapCols, map, f, M, dval, args... )
124
133
@@ -135,7 +144,7 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
135
144
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
136
145
137
146
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
138
- C = bigmap (col -> surevec (f (col + dualcol, args... ), M ), A)
147
+ C = bigmap (col -> surevec (f (col + dualcol, args... )), A)
139
148
140
149
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
141
150
0 commit comments