# dprl
Deep reinforcement learning package for torch7.
Algorithms:
- Deep Q-learning [1]
- Double DQN [2]
- Bootstrapped DQN (broken) [3]
- Asynchronous advantage actor-critic [4]
## Installation
```
git clone https://github.com/PoHsunSu/dprl.git
cd dprl
luarocks make dprl-scm-1.rockspec
```
The library provides implementations of deep reinforcement learning algorithms.
## dprl.dql
This class contains the learning and testing procedures for deep Q-learning [1].
This is the constructor of `dql`. Its arguments are:
- `env`: an environment with the interfaces defined in rlenvs
- `config`: a table containing configurations of `dql`:
  - `step`: number of steps before an episode terminates
  - `updatePeriod`: number of steps between successive updates of the target Q-network
- `statePreprop`: a function which receives an observation from `env` as argument and returns a state for `dqn`. See test-dql-catch.lua for an example.
- `actPreprop`: a function which receives the output of `dqn` and returns an action for `env`. See test-dql-catch.lua for an example.
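Putting the arguments together, construction might look like the following sketch. The `config` values are illustrative, and `dqn` and `env` are assumed to have been created already:

```lua
require 'dprl'

-- Illustrative configuration; values are examples, not recommendations
local config = {step = 1000, updatePeriod = 100}

-- Identity preprocessing (replace with task-specific functions)
local statePreprop = function(observation) return observation end
local actPreprop = function(output) return output end

local dql = dprl.dql(dqn, env, config, statePreprop, actPreprop)
```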
### dql:learn
This method implements the learning procedure of `dql`. Its arguments are:
- `episode`: the number of episodes which `dql` learns for
- `report`: a function called at each step for reporting the status of learning. Its inputs are a transition, the current step number, and the current episode number. A transition contains the following keys:
  - `s`: current state
  - `a`: current action
  - `r`: reward given action `a` at state `s`
  - `ns`: next state given action `a` at state `s`
  - `t`: boolean value telling whether `ns` is a terminal state or not
You can use `report` to compute the total reward of an episode or to print the Q value estimated by `dqn`. See test-dql-catch.lua for an example.
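For instance, a `report` callback that tracks episodic reward might look like this sketch (the accumulation logic is illustrative):

```lua
local totalReward = 0
local function report(transition, step, episode)
  totalReward = totalReward + transition.r
  if transition.t then
    print(('episode %d: total reward %g'):format(episode, totalReward))
    totalReward = 0
  end
end

dql:learn(100, report)  -- learn for 100 episodes
```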
### dql:test
This method implements the test procedure of `dql`. Its arguments are:
- `episode`: the number of episodes which `dql` tests for
- `report`: see `report` in `dql:learn`
"dqn" means deep Q-network [1]. It is the back-end of dql. It implements interfaces to train the underlying neural network model. It also implements experiment replay.
This is the constructor of `dqn`. Its arguments are:
- `qnet`: a neural network model built with the nn package. Its input is always a mini-batch of states whose dimension is defined by `statePreprop` (see dprl.dql). Its output is the estimated Q values of all possible actions.
- `config`: a table containing the following configurations of `dqn`:
  - `replaySize`: size of the replay memory
  - `batchSize`: size of the mini-batch of training cases sampled on each replay
  - `discount`: discount factor of reward
  - `epsilon`: the ε of ε-greedy exploration
- `optim`: an optimization method in the optim package for training `qnet`
- `optimConfig`: configuration of `optim`
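A construction sketch following the argument list above; the network sizes and configuration values are illustrative only:

```lua
require 'nn'
require 'optim'
require 'dprl'

-- A minimal Q-network: 4-dimensional state in, Q values for 2 actions out
local qnet = nn.Sequential()
  :add(nn.Linear(4, 32))
  :add(nn.ReLU())
  :add(nn.Linear(32, 2))

local config = {replaySize = 10000, batchSize = 32, discount = 0.99, epsilon = 0.1}
local dqn = dprl.dqn(qnet, config, optim.rmsprop, {learningRate = 1e-3})
```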
"ddqn" means double deep Q-network [2]. It inherets from dprl.dqn. We get double deep Q-learning by giving ddqn, instead of dqn, to dql .
The only difference of dprl.ddqn to dprl.dqn is how it compute target Q-value. dprl.ddqn is recommended because it alleviates the over-estimation problem of dprl.dqn [2].
This is the constructor of dprl.ddqn. Its arguments are identical to those of dprl.dqn.
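Since `ddqn` is a drop-in replacement for `dqn`, switching to double deep Q-learning is a one-line change. A sketch, where `qnet`, the config tables, and the optim arguments are as in dprl.dqn and dprl.dql:

```lua
local ddqn = dprl.ddqn(qnet, config, optim.rmsprop, {learningRate = 1e-3})
local dql  = dprl.dql(ddqn, env, dqlConfig, statePreprop, actPreprop)
```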
## dprl.bdql
dprl.bdql implements the learning procedure in Bootstrapped DQN. Except for initialization, its usage is identical to dprl.dql.
Initialize a bootstrapped deep Q-learning agent:
```
local bdql = dprl.bdql(bdqn, env, config, statePreprop, actPreprop)
```
Except for the first argument `bdqn`, which is an instance of `dprl.bdqn`, the definitions of the other arguments are the same as in `dprl.dql`.
## dprl.bdqn
dprl.bdqn inherits from dprl.dqn. It is customized for Bootstrapped Deep Q-network.
Initialize `dprl.bdqn`:
```
local bdqn = dprl.bdqn(bqnet, config, optim, optimConfig)
```
Its arguments are:
- `bqnet`: a bootstrapped neural network with module `Bootstrap`
- `config`: a table containing the following configurations for `bdqn`:
  - `replaySize`, `batchSize`, `discount`, and `epsilon`: see `config` in dprl.dqn
  - `headNum`: the number of heads in the bootstrapped neural network `bqnet`
- `optim`: see `optim` in dprl.dqn
- `optimConfig`: see `optimConfig` in dprl.dqn
## dprl.asyncl
dprl.asyncl is the framework for asynchronous learning [4]. It manages the multi-threaded procedure in asynchronous learning. The asynchronous advantage actor-critic (a3c) algorithm is realized by providing an advantage actor-critic agent (dprl.aac) to the asynchronous learning framework (dprl.asyncl). See test-a3c-atari.lua for an example.
This is the constructor of dprl.asyncl. Its arguments are:
- `asynclAgent`: a learning agent such as advantage actor-critic (dprl.aac)
- `env`: an environment with the interfaces defined in rlenvs
- `config`: a table containing configurations of `asyncl`:
  - `nthread`: number of actor-learner threads
  - `maxSteps`: maximum number of steps in an episode on testing
  - `loadPackage`: an optional function called before constructing an actor-learner thread, for loading packages
  - `loadEnv`: an optional function to load the environment (`env`) in an actor-learner thread when the environment is not serializable. Environments written in Lua are serializable, but those written in C, like the Atari emulator, are not. See test-a3c-atari.lua for an example.
- `statePreprop`: a function which receives an observation from `env` as argument and returns a state for `asynclAgent`. See test-a3c-catch.lua for an example.
- `actPreprop`: a function which receives the output of `aac` (or another learning agent) and returns an action for `env`. See test-a3c-catch.lua for an example.
- `rewardPreprop`: a function which receives a reward from `env` as argument and returns a processed reward for `asynclAgent`. See the reward clipping in test-a3c-atari.lua for an example.
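A construction sketch; the argument order is assumed to follow the list above, the config values are illustrative, and `aac` and `env` are assumed to exist:

```lua
local config = {nthread = 4, maxSteps = 10000}
local statePreprop = function(observation) return observation end
local actPreprop = function(output) return output end
local rewardPreprop = function(r) return math.max(-1, math.min(r, 1)) end  -- clip rewards to [-1, 1]

local asyncl = dprl.asyncl(aac, env, config, statePreprop, actPreprop, rewardPreprop)
```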
### asyncl:learn
This method is the learning procedure of dprl.asyncl. Its arguments are:
- `Tmax`: the limit on the total (global) number of learning steps over all actor-learner threads
- `report`: a function called at each step for reporting the status of learning. Its arguments are a transition, the count of learning steps in the thread, and the count of global learning steps. A transition contains the following keys:
  - `s`: current state
  - `a`: current action
  - `r`: reward given action `a` at state `s`
  - `ns`: next state given action `a` at state `s`
  - `t`: boolean value telling whether `ns` is a terminal state or not
## dprl.aac
This class implements the routines called by the asynchronous framework dprl.asyncl to realize the asynchronous advantage actor-critic (a3c) algorithm [4]. It trains the actor and critic neural network models, which should be provided at construction.
This is the constructor of dprl.aac. Its arguments are:
- `anet`: the actor network. Its input and output must conform to the requirements of the environment `env` in dprl.asyncl. Note that you can use `statePreprop` and `actPreprop` in dprl.asyncl.
- `cnet`: the critic network. Its input is the same as `anet`'s, and its output must be a single-value tensor.
- `config`: a table containing configurations of `aac`:
  - `tmax`: number of steps between asynchronous updates of the global parameters
  - `discount`: discount factor for computing the total reward
  - `criticGradScale`: a scale multiplying the gradient parameters of the critic network `cnet`, to allow different learning rates for the actor network and the critic network
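As a sketch, the constructor signature `dprl.aac(anet, cnet, config)` is assumed from the argument list above, and the network sizes are illustrative:

```lua
require 'nn'
require 'dprl'

-- Actor: 4-dimensional state in, action probabilities for 2 actions out
local anet = nn.Sequential()
  :add(nn.Linear(4, 32)):add(nn.ReLU())
  :add(nn.Linear(32, 2)):add(nn.SoftMax())

-- Critic: same input, a single state value out
local cnet = nn.Sequential()
  :add(nn.Linear(4, 32)):add(nn.ReLU())
  :add(nn.Linear(32, 1))

local aac = dprl.aac(anet, cnet, {tmax = 5, discount = 0.99, criticGradScale = 0.5})
```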
To share parameters between the actor network and the critic network, make sure the gradient parameters (i.e. `gradWeight` and `gradBias`) are shared as well. For example:
```
local cnet = anet:clone('weight','bias','gradWeight','gradBias')
```
## Bootstrap
This module is for constructing a bootstrapped network [3]. Let the shared network be `shareNet` and the head network be `headNet`. A bootstrapped network `bqnet` for dprl.bdqn can be constructed as follows:
```
require 'Bootstrap'
-- Definitions of the shared network 'shareNet' and the head network 'headNet' go here
-- Decorate headNet with nn.Bootstrap
local bootstrappedHeadNet = nn.Bootstrap(headNet, headNum, param_init)
-- Connect shareNet and bootstrappedHeadNet
local bqnet = nn.Sequential():add(shareNet):add(bootstrappedHeadNet)
```
Its parameters are:
- `headNum`: the number of heads of the bootstrapped network
- `param_init`: a scalar value controlling the range or variance of the parameter initialization in `headNet`. It is passed to the method `headNet:reset(param_init)` after constructing the clones of `headNet`.
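For concreteness, a complete construction might look like the following sketch, where `shareNet` and `headNet` are example networks for a 4-dimensional state and 2 actions (the sizes and `param_init` value are illustrative):

```lua
require 'nn'
require 'Bootstrap'

local shareNet = nn.Sequential():add(nn.Linear(4, 32)):add(nn.ReLU())
local headNet = nn.Linear(32, 2)  -- one head; nn.Bootstrap clones it headNum times

local headNum, param_init = 10, 0.1
local bqnet = nn.Sequential()
  :add(shareNet)
  :add(nn.Bootstrap(headNet, headNum, param_init))
```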
## References
[1] Volodymyr Mnih et al., “Human-Level Control through Deep Reinforcement Learning,” Nature 518, no. 7540 (February 26, 2015): 529–33, doi:10.1038/nature14236.
[2] Hado van Hasselt, Arthur Guez, and David Silver, “Deep Reinforcement Learning with Double Q-Learning,” arXiv:1509.06461, September 22, 2015, http://arxiv.org/abs/1509.06461.
[3] Ian Osband et al., “Deep Exploration via Bootstrapped DQN,” arXiv:1602.04621, February 15, 2016, http://arxiv.org/abs/1602.04621.
[4] Volodymyr Mnih et al., “Asynchronous Methods for Deep Reinforcement Learning,” arXiv:1602.01783, February 4, 2016, http://arxiv.org/abs/1602.01783.