import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
+import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.nd4j.linalg.learning.config.Adam;

import java.util.logging.Logger;
 */
public class Cartpole
{
-    private static QLearning.QLConfiguration CARTPOLE_QL =
-            new QLearning.QLConfiguration(
-                    123,    //Random seed
-                    200,    //Max step By epoch
-                    150000, //Max step
-                    150000, //Max size of experience replay
-                    32,     //size of batches
-                    500,    //target update (hard)
-                    10,     //num step noop warmup
-                    0.01,   //reward scaling
-                    0.99,   //gamma
-                    1.0,    //td-error clipping
-                    0.1f,   //min epsilon
-                    1000,   //num step for eps greedy anneal
-                    true    //double DQN
-            );
+    /*
+        Q learning configuration. Note that none of these are specific to the cartpole problem.
+     */
+    private static QLearning.QLConfiguration CARTPOLE_QL = QLearning.QLConfiguration.builder()
+            .seed(123)                // Random seed (for reproducibility)
+            .maxEpochStep(200)        // Max step per epoch
+            .maxStep(15000)           // Max step
+            .expRepMaxSize(150000)    // Max size of experience replay
+            .batchSize(32)            // Size of batches
+            .targetDqnUpdateFreq(500) // Target update (hard)
+            .updateStart(10)          // Num step noop warmup
+            .rewardFactor(0.01)       // Reward scaling
+            .gamma(0.99)              // Gamma
+            .errorClamp(1.0)          // TD-error clipping
+            .minEpsilon(0.1f)         // Min epsilon
+            .epsilonNbStep(1000)      // Num step for eps greedy anneal
+            .doubleDQN(true)          // Double DQN
+            .build();
+

    private static DQNFactoryStdDense.Configuration CARTPOLE_NET =
            DQNFactoryStdDense.Configuration.builder()
-                    .l2(0.001).updater(new Adam(0.0005)).numHiddenNodes(16).numLayer(3).build();
+                    .l2(0.001)
+                    .updater(new Adam(0.0005))
+                    .numHiddenNodes(16)
+                    .numLayer(3)
+                    .build();

    public static void main(String[] args) {
        DQNPolicy<Box> pol = cartPole();
@@ -59,7 +68,7 @@ public static void main(String[] args) {

    private static DQNPolicy<Box> cartPole() {
        //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace> mdp = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace>("CartPole-v0", false, false);
+        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
        QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL);

        dql.train();
@@ -72,7 +81,7 @@ private static void loadCartpole(DQNPolicy<Box> pol) {
        //use the trained agent on a new similar mdp (but render it this time)

        //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>>("CartPole-v0", true, false);
+        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>("CartPole-v0", true, false);

        //evaluate the agent
        double rewards = 0;
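        // Illustrative sketch only, not part of this commit: the hunk is cut off above.
        // Assuming RL4J's Policy.play(MDP), which returns the cumulated reward of one episode,
        // the evaluation typically continues along these lines (episode count is illustrative).
        for (int i = 0; i < 1000; i++) {
            mdp2.reset();                             // start a fresh episode
            double reward = pol.play(mdp2);           // play one episode with the trained DQN policy
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }
        Logger.getAnonymousLogger().info("average: " + rewards / 1000);
        mdp2.close();                                 // shut down the gym environment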