diff --git a/src/train.jl b/src/train.jl
index 0514ce99..6b398dc3 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -183,6 +183,126 @@ function apply_gradients(optimizer::AdamOptimizer, grads_and_vars; global_step=n
     return group(ops...)
 end
 
+mutable struct NadamOptimizer <: Optimizer
+    η::Float64
+    β1::Float64
+    β2::Float64
+    ϵ::Float64
+    name::String
+end
+
+NadamOptimizer(learning_rate; β1=.9, β2=.999, ϵ=1e-8, name="nadam") = NadamOptimizer(learning_rate, β1, β2, ϵ, name)
+
+function NadamOptimizer(; η=.001, kwargs...)
+    NadamOptimizer(η; kwargs...)
+end
+
+function Base.show(io::IO, optim::NadamOptimizer)
+    print(io, "NadamOptimizer(η=$(optim.η), β1=$(optim.β1), β2=$(optim.β2), ϵ=$(optim.ϵ))")
+end
+
+function apply_gradients(optimizer::NadamOptimizer, grads_and_vars; global_step=nothing, name="nadam")
+    ops = Tensor[]
+    @advance_step
+    for (grad, var) in grads_and_vars
+        # Slot variables: first/second moment estimates and a per-variable step counter
+        local m, v, T
+        variable_scope(name) do
+            variable_scope(node_name(var)[1]) do
+                m = get_variable("m", get_shape(var), eltype(var), initializer=ConstantInitializer(0.0), trainable=false)
+                v = get_variable("v", get_shape(var), eltype(var), initializer=ConstantInitializer(0.0), trainable=false)
+                T = get_variable("t", [], Float32, initializer=ConstantInitializer(1.0), trainable=false)
+            end
+        end
+        β1 = eltype(var)(optimizer.β1)
+        β2 = eltype(var)(optimizer.β2)
+        ϵ = eltype(var)(optimizer.ϵ)
+        η = eltype(var)(optimizer.η)
+        t = convert(Tensor{eltype(var)}, T)
+        push!(ops, tf.assign(T, T+1))
+        # Bias-corrected step size, as in Adam
+        lr = η*sqrt(1-β2^t)/(1-β1^t)
+        if isa(grad, tf.IndexedSlices)
+            m_slice = tf.gather(m, grad.indices)
+            v_slice = tf.gather(v, grad.indices)
+            m_new = β1 .* m_slice + (1-β1) .* grad.values
+            v_new = β2 .* v_slice + (1-β2) .* (grad.values .^ 2)
+            # Nesterov look-ahead: blend the updated momentum with the current gradient
+            push!(ops, tf.scatter_sub(var.var_node, grad.indices, lr/(sqrt(v_new)+ϵ) .* (β1 .* m_new + (1-β1) .* grad.values)))
+            push!(ops, tf.scatter_update(m.var_node, grad.indices, m_new))
+            push!(ops, tf.scatter_update(v.var_node, grad.indices, v_new))
+        else
+            m_new = β1 .* m + (1-β1).*grad
+            v_new = β2 .* v + (1-β2).*(grad.*grad)
+            push!(ops, tf.assign_sub(var, lr/(sqrt(v_new)+ϵ) .* (β1 .* m_new + (1-β1) .* grad)))
+            push!(ops, tf.assign(m, m_new))
+            push!(ops, tf.assign(v, v_new))
+        end
+    end
+    return group(ops...)
+end
+
+mutable struct AMSGradOptimizer <: Optimizer
+    η::Float64
+    β1::Float64
+    β2::Float64
+    ϵ::Float64
+    name::String
+end
+
+AMSGradOptimizer(learning_rate; β1=.9, β2=.999, ϵ=1e-8, name="AMSGrad") = AMSGradOptimizer(learning_rate, β1, β2, ϵ, name)
+
+function AMSGradOptimizer(; η=.001, kwargs...)
+    AMSGradOptimizer(η; kwargs...)
+end
+
+function Base.show(io::IO, optim::AMSGradOptimizer)
+    print(io, "AMSGradOptimizer(η=$(optim.η), β1=$(optim.β1), β2=$(optim.β2), ϵ=$(optim.ϵ))")
+end
+
+function apply_gradients(optimizer::AMSGradOptimizer, grads_and_vars; global_step=nothing, name="AMSGrad")
+    ops = Tensor[]
+    @advance_step
+    for (grad, var) in grads_and_vars
+        local m, v, v_hat, T
+        variable_scope(name) do
+            variable_scope(node_name(var)[1]) do
+                m = get_variable("m", get_shape(var), eltype(var), initializer=ConstantInitializer(0.0), trainable=false)
+                v = get_variable("v", get_shape(var), eltype(var), initializer=ConstantInitializer(0.0), trainable=false)
+                v_hat = get_variable("v_hat", get_shape(var), eltype(var), initializer=ConstantInitializer(0.0), trainable=false)
+                T = get_variable("t", [], Float32, initializer=ConstantInitializer(1.0), trainable=false)
+            end
+        end
+        β1 = eltype(var)(optimizer.β1)
+        β2 = eltype(var)(optimizer.β2)
+        ϵ = eltype(var)(optimizer.ϵ)
+        η = eltype(var)(optimizer.η)
+        push!(ops, tf.assign(T, T+1))
+        if isa(grad, tf.IndexedSlices)
+            m_slice = tf.gather(m, grad.indices)
+            v_slice = tf.gather(v, grad.indices)
+            v_hat_slice = tf.gather(v_hat, grad.indices)
+            m_new = β1 .* m_slice + (1-β1) .* grad.values
+            v_new = β2 .* v_slice + (1-β2) .* (grad.values .^ 2)
+            # Keep the running maximum of v so the effective step size never increases
+            v_hat_new = max(v_hat_slice, v_new)
+            push!(ops, tf.scatter_sub(var.var_node, grad.indices, η/(sqrt(v_hat_new)+ϵ) .* m_new))
+            push!(ops, tf.scatter_update(m.var_node, grad.indices, m_new))
+            push!(ops, tf.scatter_update(v.var_node, grad.indices, v_new))
+            push!(ops, tf.scatter_update(v_hat.var_node, grad.indices, v_hat_new))
+        else
+            m_new = β1 .* m + (1-β1).*grad
+            v_new = β2 .* v + (1-β2).*(grad.*grad)
+            v_hat_new = max(v_hat, v_new)
+            push!(ops, tf.assign_sub(var, η/(sqrt(v_hat_new)+ϵ) .* m_new))
+            push!(ops, tf.assign(m, m_new))
+            push!(ops, tf.assign(v, v_new))
+            push!(ops, tf.assign(v_hat, v_hat_new))
+        end
+    end
+    return group(ops...)
+end
+
 mutable struct Saver
     var_list
     max_to_keep
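Not part of the diff, but for review context: since both new types implement `apply_gradients`, they should be drop-in replacements for `AdamOptimizer` in the existing `train.minimize` flow. A minimal smoke-test sketch, assuming the standard TensorFlow.jl session API; the toy quadratic loss and step count are illustrative, not taken from the repo's tests:

```julia
using TensorFlow

sess = Session(Graph())

# One trainable vector and a quadratic loss with a known minimizer,
# just enough to exercise the new update ops.
w = get_variable("w", [3], Float32)
target = constant([1f0, 2f0, 3f0])
loss = reduce_sum((w - target) .^ 2)

# train.NadamOptimizer(.001) can be swapped in here; both constructors
# also accept the keyword form, e.g. train.AMSGradOptimizer(η=.001).
opt = train.AMSGradOptimizer(.001, β1=.9, β2=.999)
train_op = train.minimize(opt, loss)

run(sess, global_variables_initializer())
for step in 1:200
    run(sess, train_op)
end
run(sess, w)  # should be close to [1, 2, 3]
```

Because the constructors mirror `AdamOptimizer`'s (positional learning rate plus β1/β2/ϵ keywords), switching optimizers in user code is a one-line change.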