
Commit 04b0638

RobAltenatreo authored and committed
add a3c sample for cartpole.
Signed-off-by: Robert Altena <[email protected]>
1 parent 1b4a225 commit 04b0638

File tree

2 files changed: +92 -8 lines changed

rl4j-cartpole-examples/src/main/java/A3CCartpole.java

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteDense;
import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparateStdDense;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.learning.config.Adam;

import java.io.IOException;

/**
 * @author rubenfiszel ([email protected]) on 8/18/16.
 *
 * A3C on CartPole.
 * This example shows the rl4j classes that implement the algorithm from
 * "Asynchronous Methods for Deep Reinforcement Learning" (Mnih et al.):
 * https://arxiv.org/abs/1602.01783
 */
public class A3CCartpole {

    public static void main(String[] args) throws IOException {
        a3cCartPole();
    }

    private static void a3cCartPole() throws IOException {

        //define the mdp from gym (name, render)
        String envID = "CartPole-v1";
        GymEnv<Encodable, Integer, DiscreteSpace> mdp = new GymEnv<Encodable, Integer, DiscreteSpace>(envID, false, false);

        A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
                new A3CDiscrete.A3CConfiguration(
                        123,     //random seed
                        200,     //max steps per epoch
                        500000,  //max total steps
                        8,       //number of threads
                        20,      //t_max: steps between gradient updates
                        10,      //number of no-op warmup steps
                        0.01,    //reward scaling
                        0.99,    //gamma (discount factor)
                        1.0      //td-error clipping
                );

        ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration
                .builder().updater(new Adam(1e-2)).l2(0).numHiddenNodes(16).numLayer(3).build();

        //define the training
        A3CDiscreteDense<Encodable> a3c = new A3CDiscreteDense<Encodable>(mdp, CARTPOLE_NET_A3C, CARTPOLE_A3C);

        a3c.train(); //start the training
        mdp.close();

        ACPolicy<Encodable> pol = a3c.getPolicy();

        pol.save("/tmp/val1/", "/tmp/pol1");

        //reload the policy; it will be equal to "pol", but without the randomness
        ACPolicy<Box> pol2 = ACPolicy.load("/tmp/val1/", "/tmp/pol1");
        Cartpole.loadCartpole(pol2, envID);
        System.out.println("sample finished.");
    }
}
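For reference, a minimal sketch of reloading and watching the saved policy outside this class. The GymEnv constructor, ACPolicy.load, and the /tmp paths are taken from the sample above; the class name PlaySavedA3C and the use of the policy's play method to run a single rendered episode are assumptions, not part of this commit.

import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;

import java.io.IOException;

//hypothetical helper, not part of this commit
public class PlaySavedA3C {
    public static void main(String[] args) throws IOException {
        //reload the policy saved by A3CCartpole (paths taken from the sample above)
        ACPolicy<Box> pol = ACPolicy.load("/tmp/val1/", "/tmp/pol1");

        //same environment, but rendered this time so the agent can be watched
        GymEnv<Box, Integer, ActionSpace<Integer>> mdp =
                new GymEnv<Box, Integer, ActionSpace<Integer>>("CartPole-v1", true, false);

        double reward = pol.play(mdp); //play one episode, returning its cumulative reward
        System.out.println("episode reward: " + reward);
        mdp.close();
    }
}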

rl4j-cartpole-examples/src/main/java/Cartpole.java

Lines changed: 11 additions & 8 deletions
@@ -19,6 +19,7 @@
 import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
+import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
@@ -29,15 +30,17 @@
 /**
  * @author rubenfiszel ([email protected]) on 8/11/16.
  *
- * Main example for Cartpole DQN
+ * Cartpole DQN
+ * This example shows the basic rl4j classes implementing the 2013 DeepMind article by Mnih et al.:
+ * https://arxiv.org/pdf/1312.5602.pdf
  */
 public class Cartpole
 {
-    private static String envUD = "CartPole-v1";
+    private static String envID = "CartPole-v1";
 
     public static void main(String[] args) {
         DQNPolicy<Box> pol = cartPole(); //get a trained agent to play the game.
-        loadCartpole(pol); //show off the trained agent.
+        loadCartpole(pol, envID); //show off the trained agent.
     }
 
     private static DQNPolicy<Box> cartPole() {
@@ -71,10 +74,9 @@ private static DQNPolicy<Box> cartPole() {
                 .build();
 
         //Create the gym environment. We include these through the rl4j-gym dependency.
-        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>(envUD, false, false);
+        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>(envID, false, false);
 
-        //Create the solver. This class implements the 2013 article by Mnih et al. from deepmind.
-        // https://arxiv.org/pdf/1312.5602.pdf
+        //Create the solver.
         QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL);
 
         dql.train();
@@ -83,11 +85,12 @@ private static DQNPolicy<Box> cartPole() {
         return dql.getPolicy(); //return the trained agent.
     }
 
-    private static void loadCartpole(DQNPolicy<Box> pol) {
+    //pass in a generic policy and envID to allow access from other samples in this package.
+    static void loadCartpole(Policy<Box, Integer> pol, String envID) {
         //use the trained agent on a new similar mdp (but render it this time)
 
         //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>(envUD, true, false);
+        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>(envID, true, false);
 
         //evaluate the agent
         double rewards = 0;
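The hunk ends just as the evaluation begins; the remainder of loadCartpole is unchanged by this commit and not shown. As a hedged sketch only, here is how such an evaluation loop could continue, written as a helper method that could sit inside the Cartpole class; the method name evaluate, the episode count, the logging calls (assuming java.util.logging.Logger is imported), and the use of the policy's play method are assumptions, not taken from this diff.

//hypothetical helper illustrating the evaluation that follows in loadCartpole
static void evaluate(Policy<Box, Integer> pol,
                     GymEnv<Box, Integer, ActionSpace<Integer>> mdp2) {
    double rewards = 0;
    int episodes = 10; //episode count is an assumption
    for (int i = 0; i < episodes; i++) {
        mdp2.reset();                   //start a fresh episode
        double reward = pol.play(mdp2); //let the trained policy play it out
        rewards += reward;
        Logger.getAnonymousLogger().info("Reward: " + reward);
    }
    Logger.getAnonymousLogger().info("average reward: " + rewards / episodes);
}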
