diff --git a/Project.toml b/Project.toml index 538f627..145a1f2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,12 +13,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] OptimisersAdaptExt = ["Adapt"] OptimisersEnzymeCoreExt = "EnzymeCore" -OptimisersReactantExt = "Reactant" [compat] Adapt = "4" diff --git a/ext/OptimisersReactantExt.jl b/ext/OptimisersReactantExt.jl deleted file mode 100644 index ee764c8..0000000 --- a/ext/OptimisersReactantExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module OptimisersReactantExt - -import Optimisers -import Reactant - -Optimisers._eps(T::Type{<:Reactant.TracedRNumber{<:AbstractFloat}}, e) = T(e) - -end diff --git a/src/utils.jl b/src/utils.jl index 8f66746..c923290 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,8 +15,14 @@ end ofeltype(x, y) = convert(float(eltype(x)), y) -_eps(T::Type{<:AbstractFloat}, e) = T(e) -# catch complex and integers -_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e) -# avoid small e being rounded to zero +""" + _eps(Type{T}, val) + +Mostly this produces `real(T)(val)`, so that `_eps(Float32, 1e-8) === 1f-8` will +convert the Float64 parameter epsilon to work nicely with Float32 parameter arrays. + +But for Float16, it imposes a minimum of `Float16(1e-7)`, unless `val==0`. +This is basically a hack to increase the default epsilon, to help many optimisers avoid NaN. +""" +_eps(T::Type{<:Number}, e) = real(float(T))(e) _eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))