Commit 1b023ab

RobAltenatreo authored and committed
first commit
Signed-off-by: Paul Dubs <[email protected]>
1 parent 76d4760 commit 1b023ab

File tree

5 files changed: 272 additions, 0 deletions

rl4j-ale-examples/.gitignore

Lines changed: 5 additions & 0 deletions
/.idea
/target
*.iml
pong.bin
ale-a3c.model

rl4j-ale-examples/pom.xml

Lines changed: 78 additions & 0 deletions
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>ArcadeLearningEnvironment</groupId>
    <artifactId>ArcadeLearningEnvironment</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <nd4j.version>1.0.0-SNAPSHOT</nd4j.version>
        <rl4j.version>1.0.0-SNAPSHOT</rl4j.version>
        <logback.version>1.1.7</logback.version>

        <nd4j.backend>nd4j-native-platform</nd4j.backend>
    </properties>

    <repositories>
        <repository>
            <id>snapshots-repo</id>
            <url>https://oss.sonatype.org/content/repositories/snapshots</url>
            <releases>
                <enabled>false</enabled>
            </releases>
            <snapshots>
                <enabled>true</enabled>
                <updatePolicy>daily</updatePolicy> <!-- Optional, update daily -->
            </snapshots>
        </repository>
    </repositories>
    <dependencies>
        <!-- ND4J backend. You need one in every DL4J project. Normally define artifactId as either
             nd4j-native-platform or nd4j-cuda-X.X-platform to use CUDA GPUs (check the parent pom
             for supported CUDA versions). -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
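        <!-- As a sketch of that swap: to train on CUDA GPUs you would point the nd4j.backend
             property above at a CUDA platform artifact instead. The exact artifact version below
             is an assumption for illustration; check the ND4J releases for the CUDA versions
             that match your toolkit.
        <nd4j.backend>nd4j-cuda-10.0-platform</nd4j.backend>
        -->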
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-core</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-gym</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-ale</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <!-- The Arcade Learning Environment (ALE) is under the GPL license, so we cannot use it as a dependency of RL4J. -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>ale-platform</artifactId>
            <version>0.6.0-1.5.2</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>rl4j-malmo</artifactId>
            <version>${rl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

</project>

rl4j-ale-examples/src/main/java/A3CALE.java

Lines changed: 82 additions & 0 deletions
/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteConv;
import org.deeplearning4j.rl4j.mdp.ale.ALEMDP;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.nd4j.linalg.learning.config.Adam;

import java.io.IOException;

/**
 * @author saudet
 *
 * Main example for A3C with The Arcade Learning Environment (ALE)
 */
public class A3CALE {

    public static void main(String[] args) throws IOException {
        HistoryProcessor.Configuration ALE_HP = new HistoryProcessor.Configuration(
                4,       //history length
                84,      //resize width
                110,     //resize height
                84,      //crop width
                84,      //crop height
                0,       //cropping x offset
                0,       //cropping y offset
                4        //skip mod (one frame is picked every x frames)
        );

        A3CDiscrete.A3CConfiguration ALE_A3C = new A3CDiscrete.A3CConfiguration(
                123,     //random seed
                10000,   //max step by epoch
                8000000, //max step
                8,       //number of threads
                32,      //t_max
                500,     //num step noop warmup
                0.1,     //reward scaling
                0.99,    //gamma
                10.0     //td-error clipping
        );

        final ActorCriticFactoryCompGraphStdConv.Configuration ALE_NET_A3C =
                new ActorCriticFactoryCompGraphStdConv.Configuration(
                        0.000,             //l2 regularization
                        new Adam(0.00025), //learning rate
                        null, false
                );

        //set up the emulation environment through ALE; you will need a ROM file
        ALEMDP mdp = new ALEMDP("pong.bin");

        //set up the training
        A3CDiscreteConv<ALEMDP.GameScreen> a3c = new A3CDiscreteConv<ALEMDP.GameScreen>(mdp, ALE_NET_A3C, ALE_HP, ALE_A3C);

        //start the training
        a3c.train();

        //save the model at the end
        a3c.getPolicy().save("ale-a3c.model");

        //close the ALE environment
        mdp.close();
    }
}

rl4j-ale-examples/src/main/java/ALE.java

Lines changed: 81 additions & 0 deletions
/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import java.io.IOException;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteConv;
import org.deeplearning4j.rl4j.mdp.ale.ALEMDP;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;

/**
 * @author saudet
 *
 * Main example for DQN with The Arcade Learning Environment (ALE).
 * This sample shows how to set up a simple ALE for training. This setup will take a long time to master the game.
 */
public class ALE {

    public static void main(String[] args) throws IOException {

        HistoryProcessor.Configuration ALE_HP = new HistoryProcessor.Configuration(
                4,       //history length
                84,      //resize width
                110,     //resize height
                84,      //crop width
                84,      //crop height
                0,       //cropping x offset
                0,       //cropping y offset
                4        //skip mod (one frame is picked every x frames)
        );

        QLearning.QLConfiguration ALE_QL =
                new QLearning.QLConfiguration(
                        123,     //random seed
                        10000,   //max step by epoch
                        8000000, //max step
                        1000000, //max size of experience replay
                        32,      //size of batches
                        10000,   //target update (hard)
                        500,     //num step noop warmup
                        0.1,     //reward scaling
                        0.99,    //gamma
                        100.0,   //td-error clipping
                        0.1f,    //min epsilon
                        100000,  //num step for eps greedy anneal
                        true     //double-dqn
                );

        DQNFactoryStdConv.Configuration ALE_NET_QL =
                new DQNFactoryStdConv.Configuration(
                        0.00025, //learning rate
                        0.000,   //l2 regularization
                        null, null
                );

        //set up the emulation environment through ALE; you will need a ROM file.
        //Set render to true to see the agent play (poorly). You can also see how slowly the data is generated and
        //understand why training would take a long time.
        ALEMDP mdp = new ALEMDP("E:\\projects\\ArcadeLearningEnvironment\\pong.bin", false);
        //set up the training
        QLearningDiscreteConv<ALEMDP.GameScreen> dql = new QLearningDiscreteConv<ALEMDP.GameScreen>(mdp, ALE_NET_QL, ALE_HP, ALE_QL);

        dql.train();                           //start the training
        dql.getPolicy().save("ale-dql.model"); //save the model at the end
        mdp.close();
    }
}

rl4j-ale-examples/src/main/java/PlayALE.java

Lines changed: 26 additions & 0 deletions
import org.deeplearning4j.rl4j.mdp.ale.ALEMDP;
import org.deeplearning4j.rl4j.policy.ACPolicy;

import java.io.IOException;
import java.util.logging.Logger;

public class PlayALE {
    public static void main(String[] args) throws IOException {
        ALEMDP mdp = new ALEMDP("pong.bin");

        //load the previously trained agent
        ACPolicy<ALEMDP.GameScreen> pol2 = ACPolicy.load("ale-a3c.model");

        //evaluate the agent over 10 episodes
        double rewards = 0;
        for (int i = 0; i < 10; i++) {
            mdp.reset();
            double reward = pol2.play(mdp);
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        //average over the 10 episodes played above
        Logger.getAnonymousLogger().info("average: " + rewards / 10);
    }
}
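
PlayALE above evaluates only the saved A3C model. A matching sketch for the DQN model that ALE.java saves ("ale-dql.model") could look like the following; the DQNPolicy.load call is assumed to mirror ACPolicy.load, and everything else follows PlayALE.

import org.deeplearning4j.rl4j.mdp.ale.ALEMDP;
import org.deeplearning4j.rl4j.policy.DQNPolicy;

import java.io.IOException;
import java.util.logging.Logger;

public class PlayALEDQN {
    public static void main(String[] args) throws IOException {
        ALEMDP mdp = new ALEMDP("pong.bin");

        //load the DQN policy saved by ALE.java (assumes DQNPolicy.load, mirroring ACPolicy.load)
        DQNPolicy<ALEMDP.GameScreen> pol = DQNPolicy.load("ale-dql.model");

        //evaluate over 10 episodes, same loop as PlayALE
        double rewards = 0;
        for (int i = 0; i < 10; i++) {
            mdp.reset();
            double reward = pol.play(mdp);
            rewards += reward;
            Logger.getAnonymousLogger().info("Reward: " + reward);
        }

        Logger.getAnonymousLogger().info("average: " + rewards / 10);
    }
}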
