1- export grad, jacobian, jvp, j′vp, to_vec
2- replace_arg (x, xs:: Tuple , k:: Int ) = ntuple (p -> p == k ? x : xs[p], length (xs))
3-
41"""
5- grad (fdm, f, xs ...)
2+ jacobian (fdm, f, x ...)
63
7- Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)` is scalar.
4+ Approximate the Jacobian of `f` at `x` using `fdm`. Results will be returned as a
5+ `Matrix{<:Number}` of `size(length(y_vec), length(x_vec))` where `x_vec` is the flattened
6+ version of `x`, and `y_vec` the flattened version of `f(x...)`. Flattening performed by
7+ [`to_vec`](@ref).
88"""
9- function grad end
10-
11- function _grad (fdm, f, x:: AbstractArray{T} ) where T <: Number
12- # x must be mutable, we will mutate it and then mutate it back.
13- dx = similar (x)
14- for k in eachindex (x)
15- dx[k] = fdm (zero (T)) do ϵ
16- xk = x[k]
17- x[k] = xk + ϵ
18- ret = f (x)
19- x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
9+ function jacobian (fdm, f, x:: Vector{<:Number} ; len= nothing )
10+ len != = nothing && Base. depwarn (
11+ " `len` keyword argument to `jacobian` is no longer required " *
12+ " and will not be permitted in the future." ,
13+ :jacobian
14+ )
15+ ẏs = map (eachindex (x)) do n
16+ return fdm (zero (eltype (x))) do ε
17+ xn = x[n]
18+ x[n] = xn + ε
19+ ret = first (to_vec (f (x)))
20+ x[n] = xn # Can't do `x[n] -= ϵ` as floating-point math is not associative
2021 return ret
2122 end
2223 end
23- return (dx , )
24+ return (hcat (ẏs ... ) , )
2425end
2526
26- grad (fdm, f, x:: Array{<:Number} ) = _grad (fdm, f, x)
27- # Fallback for when we don't know `x` will be mutable:
28- grad (fdm, f, x:: AbstractArray{<:Number} ) = _grad (fdm, f, similar (x).= x)
29-
30- grad (fdm, f, x:: Real ) = (fdm (f, x), )
31- grad (fdm, f, x:: Tuple ) = (grad (fdm, (xs... )-> f (xs), x... ), )
32-
33- function grad (fdm, f, d:: Dict{K, V} ) where {K, V}
34- ∇d = Dict {K, V} ()
35- for (k, v) in d
36- dk = d[k]
37- function f′ (x)
38- d[k] = x
39- return f (d)
40- end
41- ∇d[k] = grad (fdm, f′, v)[1 ]
42- d[k] = dk
43- end
44- return (∇d, )
27+ function jacobian (fdm, f, x; len= nothing )
28+ x_vec, from_vec = to_vec (x)
29+ return jacobian (fdm, f ∘ from_vec, x_vec; len= len)
4530end
4631
47- function grad (fdm, f, x)
48- v, back = to_vec (x)
49- return (back (grad (fdm, x-> f (back (v)), v)), )
50- end
51-
52- function grad (fdm, f, xs... )
53- return ntuple (length (xs)) do k
54- grad (fdm, x-> f (replace_arg (x, xs, k)... ), xs[k])[1 ]
55- end
56- end
57-
58- """
59- jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; len::Int=length(f(x)))
60-
61- Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `len` vector. If
62- `len` is not provided, then `f(x)` is computed once to determine the output size.
63- """
64- function jacobian (fdm, f, x:: Union{T, AbstractArray{T}} ; len:: Int = length (f (x))) where {T <: Number }
65- J = Matrix {float(T)} (undef, len, length (x))
66- for d in 1 : len
67- gs = grad (fdm, x-> f (x)[d], x)[1 ]
68- for k in 1 : length (x)
69- J[d, k] = gs[k]
70- end
71- end
72- return (J, )
73- end
74-
75- function jacobian (fdm, f, xs... ; len:: Int = length (f (xs... )))
32+ function jacobian (fdm, f, xs... ; len= nothing )
7633 return ntuple (length (xs)) do k
7734 jacobian (fdm, x-> f (replace_arg (x, xs, k)... ), xs[k]; len= len)[1 ]
7835 end
7936end
8037
38+ replace_arg (x, xs:: Tuple , k:: Int ) = ntuple (p -> p == k ? x : xs[p], length (xs))
39+
8140"""
8241 _jvp(fdm, f, x::Vector{<:Number}, ẋ::AbstractVector{<:Number})
8342
8443Convenience function to compute `jacobian(f, x) * ẋ`.
8544"""
86- _jvp (fdm, f, x:: Vector{<:Number} , ẋ:: AV{<:Number} ) = fdm (ε -> f (x .+ ε .* ẋ), zero (eltype (x)))
87-
88- """
89- _j′vp(fdm, f, ȳ::AbstractVector{<:Number}, x::Vector{<:Number})
90-
91- Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
92- """
93- _j′vp (fdm, f, ȳ:: AV{<:Number} , x:: Vector{<:Number} ) = transpose (jacobian (fdm, f, x; len= length (ȳ))[1 ]) * ȳ
45+ function _jvp (fdm, f, x:: Vector{<:Number} , ẋ:: Vector{<:Number} )
46+ return fdm (ε -> f (x .+ ε .* ẋ), zero (eltype (x)))
47+ end
9448
9549"""
9650 jvp(fdm, f, x, ẋ)
@@ -111,110 +65,23 @@ end
11165"""
11266 j′vp(fdm, f, ȳ, x...)
11367
114- Compute an adjoint with any types of arguments for which [`to_vec`](@ref) is defined.
68+ Compute an adjoint with any types of arguments `x` for which [`to_vec`](@ref) is defined.
11569"""
11670function j′vp (fdm, f, ȳ, x)
11771 x_vec, vec_to_x = to_vec (x)
11872 ȳ_vec, _ = to_vec (ȳ)
119- return (vec_to_x (_j′vp (fdm, x_vec -> to_vec ( f ( vec_to_x (x_vec)))[ 1 ] , ȳ_vec, x_vec)), )
73+ return (vec_to_x (_j′vp (fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec)), )
12074end
121- j′vp (fdm, f, ȳ, xs... ) = j′vp (fdm, xs-> f (xs... ), ȳ, xs)[1 ]
12275
123- """
124- to_vec(x) -> Tuple{<:AbstractVector, <:Function}
125-
126- Transform `x` into a `Vector`, and return a closure which inverts the transformation.
127- """
128- function to_vec (x:: Number )
129- function Number_from_vec (x_vec)
130- return first (x_vec)
131- end
132- return [x], Number_from_vec
133- end
134-
135- # Vectors
136- to_vec (x:: Vector{<:Number} ) = (x, identity)
137- function to_vec (x:: Vector )
138- x_vecs_and_backs = map (to_vec, x)
139- x_vecs, backs = first .(x_vecs_and_backs), last .(x_vecs_and_backs)
140- function Vector_from_vec (x_vec)
141- sz = cumsum ([map (length, x_vecs)... ])
142- return [backs[n](x_vec[sz[n]- length (x_vecs[n])+ 1 : sz[n]]) for n in eachindex (x)]
143- end
144- return vcat (x_vecs... ), Vector_from_vec
145- end
146-
147- # Arrays
148- function to_vec (x:: Array{<:Number} )
149- function Array_from_vec (x_vec)
150- return reshape (x_vec, size (x))
151- end
152- return vec (x), Array_from_vec
153- end
154-
155- function to_vec (x:: Array )
156- x_vec, back = to_vec (reshape (x, :))
157- function Array_from_vec (x_vec)
158- return reshape (back (x_vec), size (x))
159- end
160- return x_vec, Array_from_vec
161- end
162-
163- # AbstractArrays
164- function to_vec (x:: T ) where {T<: LinearAlgebra.AbstractTriangular }
165- x_vec, back = to_vec (Matrix (x))
166- function AbstractTriangular_from_vec (x_vec)
167- return T (reshape (back (x_vec), size (x)))
168- end
169- return x_vec, AbstractTriangular_from_vec
170- end
171-
172- function to_vec (x:: Symmetric )
173- function Symmetric_from_vec (x_vec)
174- return Symmetric (reshape (x_vec, size (x)))
175- end
176- return vec (Matrix (x)), Symmetric_from_vec
177- end
178-
179- function to_vec (X:: Diagonal )
180- function Diagonal_from_vec (x_vec)
181- return Diagonal (reshape (x_vec, size (X)... ))
182- end
183- return vec (Matrix (X)), Diagonal_from_vec
184- end
185-
186- function to_vec (X:: Transpose )
187- function Transpose_from_vec (x_vec)
188- return Transpose (permutedims (reshape (x_vec, size (X))))
189- end
190- return vec (Matrix (X)), Transpose_from_vec
191- end
76+ j′vp (fdm, f, ȳ, xs... ) = j′vp (fdm, xs-> f (xs... ), ȳ, xs)[1 ]
19277
193- function to_vec (X:: Adjoint )
194- function Adjoint_from_vec (x_vec)
195- return Adjoint (conj! (permutedims (reshape (x_vec, size (X)))))
196- end
197- return vec (Matrix (X)), Adjoint_from_vec
78+ function _j′vp (fdm, f, ȳ:: Vector{<:Number} , x:: Vector{<:Number} )
79+ return transpose (first (jacobian (fdm, f, x))) * ȳ
19880end
19981
200- # Non-array data structures
201-
202- function to_vec (x:: Tuple )
203- x_vecs, x_backs = zip (map (to_vec, x)... )
204- sz = cumsum ([map (length, x_vecs)... ])
205- function Tuple_from_vec (v)
206- return ntuple (n-> x_backs[n](v[sz[n]- length (x_vecs[n])+ 1 : sz[n]]), length (x))
207- end
208- return vcat (x_vecs... ), Tuple_from_vec
209- end
82+ """
83+ grad(fdm, f, xs...)
21084
211- # Convert to a vector-of-vectors to make use of existing functionality.
212- function to_vec (d:: Dict )
213- d_vec_vec = [val for val in values (d)]
214- d_vec, back = to_vec (d_vec_vec)
215- function Dict_from_vec (v)
216- v_vec_vec = back (v)
217- return Dict ([(key, v_vec_vec[n]) for (n, key) in enumerate (keys (d))])
218- end
219- return d_vec, Dict_from_vec
220- end
85+ Compute the gradient of `f` for any `xs` for which [`to_vec`](@ref) is defined.
86+ """
87+ grad (fdm, f, xs... ) = j′vp (fdm, f, 1 , xs... ) # `j′vp` with seed of 1
0 commit comments