Skip to content

Commit 3a6ca29

Browse files
RobAltenatreo
authored andcommitted
updated config to have the sample learn and beat the cartpole game.
Signed-off-by: Robert Altena <[email protected]>
1 parent 746bffa commit 3a6ca29

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

rl4j-cartpole-examples/src/main/java/Cartpole.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.deeplearning4j.rl4j.space.ActionSpace;
2323
import org.deeplearning4j.rl4j.space.Box;
2424
import org.deeplearning4j.rl4j.space.DiscreteSpace;
25-
import org.nd4j.linalg.learning.config.Adam;
25+
import org.nd4j.linalg.learning.config.RmsProp;
2626

2727
import java.util.logging.Logger;
2828

@@ -41,7 +41,7 @@ public class Cartpole
4141
.maxEpochStep(200) // Max step By epoch
4242
.maxStep(15000) // Max step
4343
.expRepMaxSize(150000) // Max size of experience replay
44-
.batchSize(32) // size of batches
44+
.batchSize(128) // size of batches
4545
.targetDqnUpdateFreq(500) // target update (hard)
4646
.updateStart(10) // num step noop warmup
4747
.rewardFactor(0.01) // reward scaling
@@ -52,13 +52,12 @@ public class Cartpole
5252
.doubleDQN(true) // double DQN
5353
.build();
5454

55-
5655
private static DQNFactoryStdDense.Configuration CARTPOLE_NET =
5756
DQNFactoryStdDense.Configuration.builder()
58-
.l2(0.001)
59-
.updater(new Adam(0.0005))
60-
.numHiddenNodes(16)
61-
.numLayer(3)
57+
.l2(0)
58+
.updater(new RmsProp(0.000025))
59+
.numHiddenNodes(300)
60+
.numLayer(2)
6261
.build();
6362

6463
public static void main(String[] args) {
@@ -85,7 +84,7 @@ private static void loadCartpole(DQNPolicy<Box> pol) {
8584

8685
//evaluate the agent
8786
double rewards = 0;
88-
for (int i = 0; i < 1000; i++) {
87+
for (int i = 0; i < 10; i++) {
8988
mdp2.reset();
9089
double reward = pol.play(mdp2);
9190
rewards += reward;

0 commit comments

Comments
 (0)