
Commit 746bffa

RobAltenatreo authored and committed
implement feedback.
Signed-off-by: Robert Altena <[email protected]>
1 parent 1c3245e commit 746bffa

File tree

1 file changed: +28 additions, −19 deletions


rl4j-cartpole-examples/src/main/java/Cartpole.java

Lines changed: 28 additions & 19 deletions

@@ -19,7 +19,9 @@
 import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
+import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.nd4j.linalg.learning.config.Adam;
 
 import java.util.logging.Logger;
@@ -31,26 +33,33 @@
  */
 public class Cartpole
 {
-    private static QLearning.QLConfiguration CARTPOLE_QL =
-            new QLearning.QLConfiguration(
-                    123,     //Random seed
-                    200,     //Max step per epoch
-                    150000,  //Max step
-                    150000,  //Max size of experience replay
-                    32,      //Size of batches
-                    500,     //Target update (hard)
-                    10,      //Num step noop warmup
-                    0.01,    //Reward scaling
-                    0.99,    //Gamma
-                    1.0,     //TD-error clipping
-                    0.1f,    //Min epsilon
-                    1000,    //Num step for eps-greedy anneal
-                    true     //Double DQN
-            );
+    /*
+        Q-learning configuration. Note that none of these settings are specific to the cartpole problem.
+    */
+    private static QLearning.QLConfiguration CARTPOLE_QL = QLearning.QLConfiguration.builder()
+            .seed(123)                //Random seed (for reproducibility)
+            .maxEpochStep(200)        //Max step per epoch
+            .maxStep(15000)           //Max step
+            .expRepMaxSize(150000)    //Max size of experience replay
+            .batchSize(32)            //Size of batches
+            .targetDqnUpdateFreq(500) //Target update (hard)
+            .updateStart(10)          //Num step noop warmup
+            .rewardFactor(0.01)       //Reward scaling
+            .gamma(0.99)              //Gamma
+            .errorClamp(1.0)          //TD-error clipping
+            .minEpsilon(0.1f)         //Min epsilon
+            .epsilonNbStep(1000)      //Num step for eps-greedy anneal
+            .doubleDQN(true)          //Double DQN
+            .build();
+
 
     private static DQNFactoryStdDense.Configuration CARTPOLE_NET =
             DQNFactoryStdDense.Configuration.builder()
-                    .l2(0.001).updater(new Adam(0.0005)).numHiddenNodes(16).numLayer(3).build();
+                    .l2(0.001)
+                    .updater(new Adam(0.0005))
+                    .numHiddenNodes(16)
+                    .numLayer(3)
+                    .build();
 
     public static void main(String[] args) {
         DQNPolicy<Box> pol = cartPole();
@@ -59,7 +68,7 @@ public static void main(String[] args) {
 
     private static DQNPolicy<Box> cartPole() {
         //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace> mdp = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace>("CartPole-v0", false, false);
+        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
         QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL);
 
         dql.train();
@@ -72,7 +81,7 @@ private static void loadCartpole(DQNPolicy<Box> pol) {
         //use the trained agent on a new similar mdp (but render it this time)
 
         //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>>("CartPole-v0", true, false);
+        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>("CartPole-v0", true, false);
 
         //evaluate the agent
         double rewards = 0;
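The last hunk stops at the start of the evaluation loop. For context, a minimal sketch of how the trained policy might be evaluated against the rendered environment follows. It is not part of this commit: the class name and the "/tmp/cartpole-pol" save path are placeholders, and it assumes RL4J's DQNPolicy.load and Policy.play helpers plus a locally running Gym server, which GymEnv requires.

import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;

import java.io.IOException;
import java.util.logging.Logger;

/*
    Illustrative sketch only, not part of commit 746bffa. Assumes a policy was previously
    trained and saved (the path below is a placeholder) and that a Gym server is reachable.
*/
public class CartpoleEvalSketch {
    public static void main(String[] args) throws IOException {
        //load a previously trained and saved policy
        DQNPolicy<Box> pol = DQNPolicy.load("/tmp/cartpole-pol");

        //same mdp as in the example, but rendered this time
        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 =
                new GymEnv<Box, Integer, ActionSpace<Integer>>("CartPole-v0", true, false);

        //play a few episodes and average the cumulative reward per episode
        double rewards = 0;
        for (int i = 0; i < 10; i++) {
            double reward = pol.play(mdp2); //runs one episode and returns its total reward
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }
        Logger.getAnonymousLogger().info("Average reward: " + rewards / 10);

        mdp2.close();
    }
}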

0 commit comments
