*/
public class Cartpole
{
-    /*
-        Q learning configuration. Note that none of these are specific to the cartpole problem.
-    */
-    private static QLearning.QLConfiguration CARTPOLE_QL = QLearning.QLConfiguration.builder()
-        .seed(123)                 //Random seed (for reproducibility)
-        .maxEpochStep(200)         // Max steps per epoch
-        .maxStep(15000)            // Max steps overall
-        .expRepMaxSize(150000)     // Max size of experience replay
-        .batchSize(128)            // Size of batches
-        .targetDqnUpdateFreq(500)  // Target update (hard)
-        .updateStart(10)           // Num steps of noop warmup
-        .rewardFactor(0.01)        // Reward scaling
-        .gamma(0.99)               // Gamma
-        .errorClamp(1.0)           // TD-error clipping
-        .minEpsilon(0.1f)          // Min epsilon
-        .epsilonNbStep(1000)       // Num steps for epsilon-greedy annealing
-        .doubleDQN(true)           // Double DQN
-        .build();
-
-    private static DQNFactoryStdDense.Configuration CARTPOLE_NET =
-        DQNFactoryStdDense.Configuration.builder()
-            .l2(0)
-            .updater(new RmsProp(0.000025))
-            .numHiddenNodes(300)
-            .numLayer(2)
-            .build();
+    private static String envUD = "CartPole-v1";

    public static void main(String[] args) {
-        DQNPolicy<Box> pol = cartPole();
-        loadCartpole(pol);
+        DQNPolicy<Box> pol = cartPole();   //get a trained agent to play the game.
+        loadCartpole(pol);                 //show off the trained agent.
    }

    private static DQNPolicy<Box> cartPole() {
-        //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);
+
+        // Q learning configuration. Note that none of these are specific to the cartpole problem.
+        QLearning.QLConfiguration CARTPOLE_QL = QLearning.QLConfiguration.builder()
+            .seed(123)                 //Random seed (for reproducibility)
+            .maxEpochStep(200)         // Max steps per epoch
+            .maxStep(15000)            // Max steps overall
+            .expRepMaxSize(150000)     // Max size of experience replay
+            .batchSize(128)            // Size of batches
+            .targetDqnUpdateFreq(500)  // Target update (hard)
+            .updateStart(10)           // Num steps of noop warmup
+            .rewardFactor(0.01)        // Reward scaling
+            .gamma(0.99)               // Gamma
+            .errorClamp(1.0)           // TD-error clipping
+            .minEpsilon(0.1f)          // Min epsilon
+            .epsilonNbStep(1000)       // Num steps for epsilon-greedy annealing
+            .doubleDQN(true)           // Double DQN
+            .build();
+
+        // The neural network used by the agent. Note that there is no need to specify the number of inputs/outputs:
+        // these will be read from the gym environment at the start of training.
+        DQNFactoryStdDense.Configuration CARTPOLE_NET =
+            DQNFactoryStdDense.Configuration.builder()
+                .l2(0)
+                .updater(new RmsProp(0.000025))
+                .numHiddenNodes(300)
+                .numLayer(2)
+                .build();
+
+        //Create the gym environment. We include these through the rl4j-gym dependency.
+        GymEnv<Box, Integer, DiscreteSpace> mdp = new GymEnv<Box, Integer, DiscreteSpace>(envUD, false, false);
+
+        //Create the solver. This class implements the 2013 DQN article by Mnih et al. from DeepMind:
+        // https://arxiv.org/pdf/1312.5602.pdf
        QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL);

        dql.train();
        mdp.close();

-        return dql.getPolicy(); //get the final policy
+        return dql.getPolicy(); //return the trained agent.
    }

    private static void loadCartpole(DQNPolicy<Box> pol) {
        //use the trained agent on a new similar mdp (but render it this time)

        //define the mdp from gym (name, render)
-        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>("CartPole-v0", true, false);
+        GymEnv<Box, Integer, ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, ActionSpace<Integer>>(envUD, true, false);

        //evaluate the agent
        double rewards = 0;
@@ -92,5 +99,6 @@ private static void loadCartpole(DQNPolicy<Box> pol) {
        }

        Logger.getAnonymousLogger().info("average: " + rewards/1000);
+        mdp2.close();
    }
}
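Note: the diff starts at line 33, so the file's import block is not shown above. For anyone copying the class, it presumably relies on imports along these lines (package paths taken from the RL4J/ND4J APIs this example targets; treat the exact paths as assumptions and verify them against your RL4J version):

import java.util.logging.Logger;

import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.nd4j.linalg.learning.config.RmsProp;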
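The second hunk header (@@ -92,5 +99,6 @@) hides the unchanged body of the evaluation loop, so those lines are not part of this diff. Purely as an illustration of what such a loop typically looks like with RL4J (not the hidden code of this commit), here is a minimal sketch assuming Policy#play(MDP), which plays one episode on the environment and returns its total reward; a 1000-episode loop would also be consistent with the rewards/1000 average in the visible tail of the hunk:

// Illustrative sketch only -- not the hidden lines of this commit.
// Uses the mdp2, pol and rewards variables visible in loadCartpole above.
for (int i = 0; i < 1000; i++) {
    mdp2.reset();                    // start a fresh episode
    double reward = pol.play(mdp2);  // let the trained DQN policy act until the episode ends
    rewards += reward;
    Logger.getAnonymousLogger().info("Reward: " + reward);
}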