1
1
# # Generic ##
2
2
3
+ if VERSION < v " 1.1"
4
+ eachcol (A:: AbstractVecOrMat ) = (view (A, :, i) for i in axes (A, 2 ))
5
+ end
6
+
7
+ Base. one (:: Irrational ) = true
8
+
9
+ function vcatmapreduce (f, args... )
10
+ init = vcat (f (first .(args)... ,))
11
+ zipped_args = zip (args... ,)
12
+ return mapreduce (vcat, drop (zipped_args, 1 ); init = init) do zarg
13
+ f (zarg... ,)
14
+ end
15
+ end
16
+ @adjoint function vcatmapreduce (f, args... )
17
+ g (f, args... ) = f .(args... ,)
18
+ return pullback (g, f, args... )
19
+ end
20
+
3
21
function Base. fill (
4
22
value:: TrackedReal ,
5
23
dims:: Vararg{Union{Integer, AbstractUnitRange}} ,
6
24
)
7
25
return track (fill, value, dims... )
8
26
end
9
- Tracker . @grad function Base. fill (value:: Real , dims... )
27
+ @grad function Base. fill (value:: Real , dims... )
10
28
return fill (data (value), dims... ), function (Δ)
11
29
size (Δ) ≢ dims && error (" Dimension mismatch" )
12
30
return (sum (Δ), map (_-> nothing , dims)... )
16
34
# # StatsFuns ##
17
35
18
36
logsumexp (x:: TrackedArray ) = track (logsumexp, x)
19
- Tracker . @grad function logsumexp (x:: TrackedArray )
37
+ @grad function logsumexp (x:: TrackedArray )
20
38
lse = logsumexp (data (x))
21
39
return lse, Δ -> (Δ .* exp .(x .- lse),)
22
40
end
23
41
24
42
# # Linear algebra ##
25
43
26
44
LinearAlgebra. UpperTriangular (A:: TrackedMatrix ) = track (UpperTriangular, A)
27
- Tracker . @grad function LinearAlgebra. UpperTriangular (A:: AbstractMatrix )
45
+ @grad function LinearAlgebra. UpperTriangular (A:: AbstractMatrix )
28
46
return UpperTriangular (data (A)), Δ-> (UpperTriangular (Δ),)
29
47
end
30
48
@@ -39,27 +57,27 @@ function turing_chol(A::AbstractMatrix, check)
39
57
(chol. factors, chol. info)
40
58
end
41
59
turing_chol (A:: TrackedMatrix , check) = track (turing_chol, A, check)
42
- Tracker . @grad function turing_chol (A:: AbstractMatrix , check)
60
+ @grad function turing_chol (A:: AbstractMatrix , check)
43
61
C, back = pullback (unsafe_cholesky, data (A), data (check))
44
62
return (C. factors, C. info), Δ-> back ((factors= data (Δ[1 ]),))
45
63
end
46
64
47
65
unsafe_cholesky (x, check) = cholesky (x, check= check)
48
- ZygoteRules . @adjoint function unsafe_cholesky (Σ:: Real , check)
66
+ @adjoint function unsafe_cholesky (Σ:: Real , check)
49
67
C = cholesky (Σ; check= check)
50
68
return C, function (Δ:: NamedTuple )
51
69
issuccess (C) || return (zero (Σ), nothing )
52
70
(Δ. factors[1 , 1 ] / (2 * C. U[1 , 1 ]), nothing )
53
71
end
54
72
end
55
- ZygoteRules . @adjoint function unsafe_cholesky (Σ:: Diagonal , check)
73
+ @adjoint function unsafe_cholesky (Σ:: Diagonal , check)
56
74
C = cholesky (Σ; check= check)
57
75
return C, function (Δ:: NamedTuple )
58
76
issuccess (C) || (Diagonal (zero (diag (Δ. factors))), nothing )
59
77
(Diagonal (diag (Δ. factors) .* inv .(2 .* C. factors. diag)), nothing )
60
78
end
61
79
end
62
- ZygoteRules . @adjoint function unsafe_cholesky (Σ:: Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}} , check)
80
+ @adjoint function unsafe_cholesky (Σ:: Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}} , check)
63
81
C = cholesky (Σ; check= check)
64
82
return C, function (Δ:: NamedTuple )
65
83
issuccess (C) || return (zero (Δ. factors), nothing )
78
96
# Specialised logdet for cholesky to target the triangle directly.
79
97
logdet_chol_tri (U:: AbstractMatrix ) = 2 * sum (log, U[diagind (U)])
80
98
logdet_chol_tri (U:: TrackedMatrix ) = track (logdet_chol_tri, U)
81
- Tracker . @grad function logdet_chol_tri (U:: AbstractMatrix )
99
+ @grad function logdet_chol_tri (U:: AbstractMatrix )
82
100
U_data = data (U)
83
101
return logdet_chol_tri (U_data), Δ-> (Matrix (Diagonal (2 .* Δ ./ diag (U_data))),)
84
102
end
@@ -88,6 +106,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
88
106
end
89
107
90
108
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
109
+
91
110
zygote_ldiv (A:: AbstractMatrix , B:: AbstractVecOrMat ) = A \ B
92
111
function zygote_ldiv (A:: TrackedMatrix , B:: TrackedVecOrMat )
93
112
return track (zygote_ldiv, A, B)
@@ -96,11 +115,49 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
96
115
return track (zygote_ldiv, A, B)
97
116
end
98
117
zygote_ldiv (A:: AbstractMatrix , B:: TrackedVecOrMat ) = track (zygote_ldiv, A, B)
99
- Tracker . @grad function zygote_ldiv (A, B)
118
+ @grad function zygote_ldiv (A, B)
100
119
Y, back = pullback (\ , data (A), data (B))
101
120
return Y, Δ-> back (data (Δ))
102
121
end
103
122
104
123
function Base.:\ (a:: Cholesky{<:TrackedReal, <:TrackedArray} , b:: AbstractVecOrMat )
105
124
return (a. U \ (a. U' \ b))
106
125
end
126
+
127
+ # SpecialFunctions
128
+
129
+ SpecialFunctions. logabsgamma (x:: TrackedReal ) = track (logabsgamma, x)
130
+ @grad function SpecialFunctions. logabsgamma (x:: Real )
131
+ return logabsgamma (data (x)), Δ -> (digamma (data (x)) * Δ[1 ],)
132
+ end
133
+ @adjoint function SpecialFunctions. logabsgamma (x:: Real )
134
+ return logabsgamma (x), Δ -> (digamma (x) * Δ[1 ],)
135
+ end
136
+
137
+ # Some Tracker fixes
138
+
139
+ for i = 0 : 2 , c = Tracker. combinations ([:AbstractArray , :TrackedArray , :TrackedReal , :Number ], i), f = [:hcat , :vcat ]
140
+ if :TrackedReal in c
141
+ cnames = map (_ -> gensym (), c)
142
+ @eval Base.$ f ($ ([:($ x:: $c ) for (x, c) in zip (cnames, c)]. .. ), x:: Union{TrackedArray,TrackedReal} , xs:: Union{AbstractArray,Number} ...) =
143
+ track ($ f, $ (cnames... ), x, xs... )
144
+ end
145
+ end
146
+ @grad function vcat (x:: Real )
147
+ vcat (data (x)), (Δ) -> (Δ[1 ],)
148
+ end
149
+ @grad function vcat (x1:: Real , x2:: Real )
150
+ vcat (data (x1), data (x2)), (Δ) -> (Δ[1 ], Δ[2 ])
151
+ end
152
+ @grad function vcat (x1:: AbstractVector , x2:: Real )
153
+ vcat (data (x1), data (x2)), (Δ) -> (Δ[1 : length (x1)], Δ[length (x1)+ 1 ])
154
+ end
155
+
156
+ # Zygote fill has issues with non-numbers
157
+
158
+ @adjoint function fill (x:: T , dims... ) where {T}
159
+ function zfill (x, dims... ,)
160
+ return reshape ([x for i in 1 : prod (dims)], dims)
161
+ end
162
+ pullback (zfill, x, dims... )
163
+ end
0 commit comments