22
22
import org .deeplearning4j .rl4j .space .ActionSpace ;
23
23
import org .deeplearning4j .rl4j .space .Box ;
24
24
import org .deeplearning4j .rl4j .space .DiscreteSpace ;
25
- import org .nd4j .linalg .learning .config .Adam ;
25
+ import org .nd4j .linalg .learning .config .RmsProp ;
26
26
27
27
import java .util .logging .Logger ;
28
28
@@ -41,7 +41,7 @@ public class Cartpole
41
41
.maxEpochStep (200 ) // Max step By epoch
42
42
.maxStep (15000 ) // Max step
43
43
.expRepMaxSize (150000 ) // Max size of experience replay
44
- .batchSize (32 ) // size of batches
44
+ .batchSize (128 ) // size of batches
45
45
.targetDqnUpdateFreq (500 ) // target update (hard)
46
46
.updateStart (10 ) // num step noop warmup
47
47
.rewardFactor (0.01 ) // reward scaling
@@ -52,13 +52,12 @@ public class Cartpole
52
52
.doubleDQN (true ) // double DQN
53
53
.build ();
54
54
55
-
56
55
private static DQNFactoryStdDense .Configuration CARTPOLE_NET =
57
56
DQNFactoryStdDense .Configuration .builder ()
58
- .l2 (0.001 )
59
- .updater (new Adam (0.0005 ))
60
- .numHiddenNodes (16 )
61
- .numLayer (3 )
57
+ .l2 (0 )
58
+ .updater (new RmsProp (0.000025 ))
59
+ .numHiddenNodes (300 )
60
+ .numLayer (2 )
62
61
.build ();
63
62
64
63
public static void main (String [] args ) {
@@ -85,7 +84,7 @@ private static void loadCartpole(DQNPolicy<Box> pol) {
85
84
86
85
//evaluate the agent
87
86
double rewards = 0 ;
88
- for (int i = 0 ; i < 1000 ; i ++) {
87
+ for (int i = 0 ; i < 10 ; i ++) {
89
88
mdp2 .reset ();
90
89
double reward = pol .play (mdp2 );
91
90
rewards += reward ;
0 commit comments