@@ -6,7 +6,7 @@ export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
6
6
using JuliennedArrays
7
7
8
8
using Tracker
9
- using Tracker: TrackedMatrix, track, @grad , data
9
+ using Tracker: TrackedReal, TrackedMatrix, track, @grad , data
10
10
11
11
using ZygoteRules
12
12
using ZygoteRules: pullback, @adjoint
@@ -29,18 +29,15 @@ mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
29
29
tmapcols (f:: Function , M, args... ) = _mapcols (threadmap, f, M, args... )
30
30
31
31
_mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
32
- reduce (hcat, map (col -> surevec (f (col, args... )), eachcol (M)))
33
-
34
- surevec (x:: Number ) = [x] # to allow f vector -> scalar, as mapslices does
35
- surevec (A) = vec (A) # to allow f vector -> matrix, by reshaping
32
+ reduce (hcat, map (col -> surevec (f (col, args... ), M), eachcol (M)))
36
33
37
34
_mapcols (map:: Function , f:: Function , M:: TrackedMatrix , args... ) = track (_mapcols, map, f, M, args... )
38
35
39
36
@grad _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
40
- ∇mapcols (map, map (col -> Tracker. forward (x -> surevec (f (x, args... )), col), eachcol (data (M))), args... )
37
+ ∇mapcols (map, map (col -> Tracker. forward (x -> surevec (f (x, args... ), M ), col), eachcol (data (M))), args... )
41
38
42
39
@adjoint _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
43
- ∇mapcols (map, map (col -> ZygoteRules. pullback (x -> surevec (f (x, args... )), col), eachcol (M)), args)
40
+ ∇mapcols (map, map (col -> ZygoteRules. pullback (x -> surevec (f (x, args... ), M ), col), eachcol (M)), args)
44
41
45
42
function ∇mapcols (bigmap, forwards, args... )
46
43
reduce (hcat, map (data∘ first, forwards)), Δ -> begin
@@ -49,21 +46,29 @@ function ∇mapcols(bigmap, forwards, args...)
49
46
end
50
47
end
51
48
49
+ surevec (A:: AbstractArray , M) = vec (A) # to allow f vector -> matrix, by reshaping
50
+ surevec (x:: Number , M) = _veclike (x, M) # to allow f vector -> scalar, as mapslices does
51
+
52
+ _veclike (x:: T , M) where {T} = fill! (similar (M, T, 1 ), x) # use similar to preserve CuArrays, #4
53
+ _veclike (x:: TrackedReal , M) = track (_veclike, x, M)
54
+ @grad _veclike (x, M) = _veclike (data (x), M), Δ -> (first (Δ), nothing )
55
+ @adjoint _veclike (x, M) = _veclike (x, M), Δ -> (first (Δ), nothing )
56
+
52
57
"""
53
58
maprows(f, M) ≈ mapslices(f, M, dims=2)
54
59
55
60
Like `mapcols()` but for rows.
56
61
"""
57
62
maprows (f:: Function , M:: AbstractMatrix , args... ) =
58
- reduce (vcat, map (col -> transpose (surevec (f (col, args... ))), eachrow (M)))
63
+ reduce (vcat, map (col -> transpose (surevec (f (col, args... ), M )), eachrow (M)))
59
64
60
65
maprows (f:: Function , M:: TrackedMatrix , args... ) = track (maprows, f, M, args... )
61
66
62
67
@grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
63
- ∇maprows (map (row -> Tracker. forward (x -> surevec (f (x, args... )), row), eachrow (data (M))), args)
68
+ ∇maprows (map (row -> Tracker. forward (x -> surevec (f (x, args... ), M ), row), eachrow (data (M))), args)
64
69
65
70
@adjoint maprows (f:: Function , M:: AbstractMatrix , args... ) =
66
- ∇maprows (map (row -> ZygoteRules. pullback (x -> surevec (f (x, args... )), row), eachrow (M)), args)
71
+ ∇maprows (map (row -> ZygoteRules. pullback (x -> surevec (f (x, args... ), M ), row), eachrow (M)), args)
67
72
68
73
function ∇maprows (forwards, args)
69
74
reduce (vcat, map (transpose∘ data∘ first, forwards)), Δ -> begin
@@ -110,7 +115,7 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
110
115
function _MapCols (map:: Function , f:: Function , M:: Matrix{T} , :: Val{d} , args... ) where {T,d}
111
116
d == size (M,1 ) || error (" expected M with $d rows" )
112
117
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
113
- B = map (col -> surevec (f (col, args... )), A)
118
+ B = map (col -> surevec (f (col, args... ), M ), A)
114
119
reduce (hcat, B)
115
120
end
116
121
@@ -130,7 +135,7 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
130
135
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
131
136
132
137
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
133
- C = bigmap (col -> surevec (f (col + dualcol, args... )), A)
138
+ C = bigmap (col -> surevec (f (col + dualcol, args... ), M ), A)
134
139
135
140
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
136
141
0 commit comments