@@ -8,33 +8,40 @@ Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)`
88"""
99function grad end
1010
11- function grad (fdm, f, x:: AbstractArray{T} ) where T <: Number
11+ function _grad (fdm, f, x:: AbstractArray{T} ) where T <: Number
12+ # x must be mutable, we will mutate it and then mutate it back.
1213 dx = similar (x)
13- tmp = similar (x)
1414 for k in eachindex (x)
1515 dx[k] = fdm (zero (T)) do ϵ
16- tmp .= x
17- tmp[k] += ϵ
18- return f (tmp)
16+ xk = x[k]
17+ x[k] = xk + ϵ
18+ ret = f (x)
19+ x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
20+ return ret
1921 end
2022 end
2123 return (dx, )
2224end
2325
26+ grad (fdm, f, x:: Array{<:Number} ) = _grad (fdm, f, x)
27+ # Fallback for when we don't know `x` will be mutable:
28+ grad (fdm, f, x:: AbstractArray{<:Number} ) = _grad (fdm, f, similar (x).= x)
29+
2430grad (fdm, f, x:: Real ) = (fdm (f, x), )
2531grad (fdm, f, x:: Tuple ) = (grad (fdm, (xs... )-> f (xs), x... ), )
2632
2733function grad (fdm, f, d:: Dict{K, V} ) where {K, V}
28- dd = Dict {K, V} ()
34+ ∇d = Dict {K, V} ()
2935 for (k, v) in d
36+ dk = d[k]
3037 function f′ (x)
31- tmp = copy (d)
32- tmp[k] = x
33- return f (tmp)
38+ d[k] = x
39+ return f (d)
3440 end
35- dd[k] = grad (fdm, f′, v)[1 ]
41+ ∇d[k] = grad (fdm, f′, v)[1 ]
42+ d[k] = dk
3643 end
37- return (dd , )
44+ return (∇d , )
3845end
3946
4047function grad (fdm, f, x)
0 commit comments