
Commit 2f1d131

Added testing version of optimised RMSProp from Atari paper
1 parent 666d769 commit 2f1d131

File tree

3 files changed: +43 −0 lines changed

DeepFried2/optimizers/DQNProp.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
import DeepFried2 as df


class DQNProp(df.Optimizer):
    """
    RMSProp as described on page 23 of http://arxiv.org/pdf/1308.0850v5.pdf

    Also used by DeepMind (in NeuralQLearner.lua) here:
    https://sites.google.com/a/deepmind.com/dqn/

    The updates are:

        g_{e+1} = ρ * g_e + (1-ρ) * ∇p_e
        g²_{e+1} = ρ * g²_e + (1-ρ) * ∇p_e²
        p_{e+1} = p_e - lr * ∇p_e / √(g²_{e+1} - g_{e+1}² + ε)

    where g² is a running average of the squared gradient, not the square
    of g. This roughly corresponds to dividing the gradients by their
    standard deviation over the past batches, in a rolling-momentum fashion:
    the more "unstable" a gradient, the lower its effective learning-rate.
    """

    def __init__(self, lr, rho, eps=1e-7):
        df.Optimizer.__init__(self, lr=lr, rho=rho, eps=eps)

    def get_updates(self, params, grads, lr, rho, eps):
        updates = []

        for param, grad in zip(params, grads):
            # Rolling averages of the gradient and of its square.
            g_state = df.utils.create_param_state_as(param)
            new_g = rho*g_state + (1-rho)*grad
            g2_state = df.utils.create_param_state_as(param)
            new_g2 = rho*g2_state + (1-rho)*grad*grad

            updates.append((g_state, new_g))
            updates.append((g2_state, new_g2))
            # new_g2 - new_g² estimates the gradient's variance; eps guards
            # against division by zero.
            updates.append((param, param - lr*grad/df.T.sqrt(new_g2 - new_g*new_g + eps)))

        return updates
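
To make the docstring's equations concrete, here is a small self-contained sketch (not part of the commit) that replays the same update rule in plain NumPy and compares the effective step size for a steady gradient against an equally large but oscillating one. The function name dqnprop_step and the toy gradient sequences are illustrative only, not DeepFried2 API:

import numpy as np

lr, rho, eps = 1e-3, 0.95, 1e-7

def dqnprop_step(p, g, g2, grad):
    # One step of the update rule from the docstring above.
    g = rho*g + (1 - rho)*grad           # rolling mean of the gradient
    g2 = rho*g2 + (1 - rho)*grad**2      # rolling mean of its square
    # g2 - g**2 estimates the gradient's variance, so the step divides the
    # raw gradient by (roughly) its standard deviation.
    return p - lr*grad/np.sqrt(g2 - g**2 + eps), g, g2

for label, grads in [("steady", np.ones(50)),
                     ("oscillating", np.resize([1.0, -1.0], 50))]:
    p, g, g2 = 0.0, 0.0, 0.0
    for grad in grads:
        p, g, g2 = dqnprop_step(p, g, g2, grad)
    print(label, "effective step size:", lr/np.sqrt(g2 - g**2 + eps))

The steady gradient ends up with an effective step size several times larger than the oscillating one, which is exactly the behaviour the docstring describes: the noisier the gradient, the larger its estimated variance and the smaller its effective learning-rate.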

DeepFried2/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@
 from .Nesterov import Nesterov
 from .AdaGrad import AdaGrad
 from .RMSProp import RMSProp
+from .DQNProp import DQNProp
 from .AdaDelta import AdaDelta
 from .Adam import Adam

examples/Optimizers/run.py

Lines changed: 1 addition & 0 deletions
@@ -43,5 +43,6 @@ def run(optim):
 run(df.Nesterov(lr=1e-2, momentum=0.90))
 run(df.AdaGrad(lr=1e-2, eps=1e-4))
 run(df.RMSProp(lr=1e-3, rho=0.90, eps=1e-5))
+run(df.DQNProp(lr=1e-3, rho=0.90, eps=1e-5))
 run(df.AdaDelta(rho=0.99, lr=5e-1, eps=1e-4))
 run(df.Adam(alpha=1e-3, beta1=0.95, beta2=0.9, eps=1e-8))
