Skip to content

Commit 2fa23a1

Browse files
RobAltenatreo
authored and committed
first commit.
Signed-off-by: Robert Altena <[email protected]>
1 parent 8c5b8a7 commit 2fa23a1

File tree

3 files changed

+177
-0
lines changed

3 files changed

+177
-0
lines changed

rl4j-cartpole-examples/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/.idea
2+
/target
3+
*.iml

rl4j-cartpole-examples/pom.xml

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the stand-alone RL4J CartPole example. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>cartpole</groupId>
    <artifactId>cartpole</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Pin the language level and source encoding explicitly so the build does not
             depend on maven-compiler-plugin defaults or on the platform charset.
             NOTE(review): 1.8 is assumed to be the baseline required by the
             1.0.0-SNAPSHOT ND4J/RL4J artifacts; confirm before merging. -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

        <nd4j.version>1.0.0-SNAPSHOT</nd4j.version>
        <rl4j.version>1.0.0-SNAPSHOT</rl4j.version>
        <logback.version>1.1.7</logback.version>

        <!-- Artifact id of the ND4J backend; referenced by the first dependency below. -->
        <nd4j.backend>nd4j-native-platform</nd4j.backend>
    </properties>

    <!-- The ND4J/RL4J versions above are snapshots, so the Sonatype snapshot repository
         is required; release resolution from it is disabled. -->
    <repositories>
        <repository>
            <id>snapshots-repo</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
                <updatePolicy>daily</updatePolicy> <!-- Optional, update daily -->
            </snapshots>
        </repository>
    </repositories>

    <dependencies>
        <!-- ND4J backend. You need one in every DL4J project. Normally define the
             nd4j.backend property as either nd4j-native-platform or
             nd4j-cuda-X.X-platform to use CUDA GPUs (check the ND4J release you
             depend on for the supported CUDA versions). -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-gym</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <!-- NOTE(review): the ALE and Malmo artifacts below are kept as committed; confirm
             whether the CartPole example actually needs them. -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-ale</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <!-- The Arcade Learning Environment (ALE) is under GPL license, so we cannot use it as a dependency of RL4J. -->
        <!--
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ale-platform</artifactId>
            <version>0.6.0-1.5</version>
        </dependency>
        -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-malmo</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>com.microsoft.msr.malmo</groupId>
            <artifactId>MalmoJavaJar</artifactId>
            <version>0.30.0</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>
</project>
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* *****************************************************************************
2+
* Copyright (c) 2015-2019 Skymind, Inc.
3+
*
4+
* This program and the accompanying materials are made available under the
5+
* terms of the Apache License, Version 2.0 which is available at
6+
* https://www.apache.org/licenses/LICENSE-2.0.
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
10+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
11+
* License for the specific language governing permissions and limitations
12+
* under the License.
13+
*
14+
* SPDX-License-Identifier: Apache-2.0
15+
******************************************************************************/
16+
17+
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
18+
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
19+
import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
20+
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
21+
import org.deeplearning4j.rl4j.policy.DQNPolicy;
22+
import org.deeplearning4j.rl4j.space.Box;
23+
import org.nd4j.linalg.learning.config.Adam;
24+
25+
import java.io.IOException;
26+
import java.util.logging.Logger;
27+
28+
/**
29+
* @author rubenfiszel ([email protected]) on 8/11/16.
30+
*
31+
* Main example for Cartpole DQN
32+
*/
33+
public class Cartpole
34+
{
35+
private static QLearning.QLConfiguration CARTPOLE_QL =
36+
new QLearning.QLConfiguration(
37+
123, //Random seed
38+
200, //Max step By epoch
39+
150000, //Max step
40+
150000, //Max size of experience replay
41+
32, //size of batches
42+
500, //target update (hard)
43+
10, //num step noop warmup
44+
0.01, //reward scaling
45+
0.99, //gamma
46+
1.0, //td-error clipping
47+
0.1f, //min epsilon
48+
1000, //num step for eps greedy anneal
49+
true //double DQN
50+
);
51+
52+
private static DQNFactoryStdDense.Configuration CARTPOLE_NET =
53+
DQNFactoryStdDense.Configuration.builder()
54+
.l2(0.001).updater(new Adam(0.0005)).numHiddenNodes(16).numLayer(3).build();
55+
56+
public static void main(String[] args) throws IOException {
57+
DQNPolicy<Box> pol = cartPole();
58+
loadCartpole(pol);
59+
}
60+
61+
private static DQNPolicy<Box> cartPole() throws IOException {
62+
//define the mdp from gym (name, render)
63+
GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace> mdp = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.DiscreteSpace>("CartPole-v0", false, false);
64+
QLearningDiscreteDense<Box> dql = new QLearningDiscreteDense<Box>(mdp, CARTPOLE_NET, CARTPOLE_QL);
65+
66+
dql.train();
67+
mdp.close();
68+
69+
return dql.getPolicy(); //get the final policy
70+
}
71+
72+
private static void loadCartpole(DQNPolicy<Box> pol) throws IOException {
73+
//use the trained agent on a new similar mdp (but render it this time)
74+
75+
//define the mdp from gym (name, render)
76+
GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>> mdp2 = new GymEnv<Box, Integer, org.deeplearning4j.rl4j.space.ActionSpace<Integer>>("CartPole-v0", true, false);
77+
78+
//evaluate the agent
79+
double rewards = 0;
80+
for (int i = 0; i < 1000; i++) {
81+
mdp2.reset();
82+
double reward = pol.play(mdp2);
83+
rewards += reward;
84+
Logger.getAnonymousLogger().info("Reward: " + reward);
85+
}
86+
87+
Logger.getAnonymousLogger().info("average: " + rewards/1000);
88+
}
89+
}

0 commit comments

Comments
 (0)