Skip to content

Commit 7b0a63c

Browse files
RobAltenatreo
authored and committed
add test.
Signed-off-by: Paul Dubs <[email protected]>
1 parent 1b023ab commit 7b0a63c

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed

rl4j-ale-examples/pom.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@
6363
<version>${rl4j.version}</version>
6464
</dependency>
6565
<dependency>
66-
<groupId>junit</groupId>
67-
<artifactId>junit</artifactId>
68-
<version>3.8.1</version>
66+
<groupId>org.junit.jupiter</groupId>
67+
<artifactId>junit-jupiter-engine</artifactId>
68+
<version>5.4.2</version>
6969
<scope>test</scope>
7070
</dependency>
7171
<dependency>

rl4j-ale-examples/src/main/java/PlayALE.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
public class PlayALE {
88
public static void main(String[] args) throws IOException {
9-
ALEMDP mdp = new ALEMDP("pong.bin");
9+
ALEMDP mdp = new ALEMDP("E:\\projects\\ArcadeLearningEnvironment\\pong.bin");
1010

1111
//load the previous agent
1212
ACPolicy<ALEMDP.GameScreen> pol2 = ACPolicy.load("ale-a3c.model");
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete;
import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteConv;
import org.deeplearning4j.rl4j.mdp.ale.ALEMDP;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

import java.io.IOException;
16+
public class AleTest {
17+
18+
@Test
19+
public void TrainModelDataShape(){
20+
21+
// Set up the training as in the sample.
22+
HistoryProcessor.Configuration ALE_HP = new HistoryProcessor.Configuration(
23+
4, //History length
24+
84, //resize width
25+
110, //resize height
26+
84, //crop width
27+
84, //crop height
28+
0, //cropping x offset
29+
0, //cropping y offset
30+
4 //skip mod (one frame is picked every x
31+
);
32+
33+
A3CDiscrete.A3CConfiguration ALE_A3C = new A3CDiscrete.A3CConfiguration(
34+
123, //Random seed
35+
10000, //Max step By epoch
36+
8000000, //Max step
37+
8, //Number of threads
38+
32, //t_max
39+
500, //num step noop warmup
40+
0.1, //reward scaling
41+
0.99, //gamma
42+
10.0 //td-error clipping
43+
);
44+
45+
final ActorCriticFactoryCompGraphStdConv.Configuration ALE_NET_A3C =
46+
new ActorCriticFactoryCompGraphStdConv.Configuration(
47+
0.000, //l2 regularization
48+
new Adam(0.00025), //learning rate
49+
null, false
50+
);
51+
ALEMDP mdp = new ALEMDP("pong.bin");
52+
A3CDiscreteConv<ALEMDP.GameScreen> a3c = new A3CDiscreteConv<ALEMDP.GameScreen>(mdp, ALE_NET_A3C, ALE_HP, ALE_A3C);
53+
54+
NeuralNetwork [] nns = a3c.getNeuralNet().getNeuralNetworks();
55+
ComputationGraph g = (ComputationGraph ) nns[0];
56+
57+
// Now pass in some dummy data in the expected shape.
58+
INDArray dummy = Nd4j.rand( 1,4, 84, 84);
59+
g.output(new INDArray[] {dummy}); //If we get the shape wrong we crash here.
60+
}
61+
62+
@Test
63+
void LoadModel() throws IOException {
64+
//load the previous agent
65+
ACPolicy<ALEMDP.GameScreen> pol = ACPolicy.load("ale-a3c.model");
66+
NeuralNetwork [] nns = pol.getNeuralNet().getNeuralNetworks();
67+
ComputationGraph g = (ComputationGraph ) nns[0];
68+
69+
// Now pass in some dummy data in the expected shape.
70+
INDArray dummy = Nd4j.rand( 1,4, 84, 84);
71+
g.output(new INDArray[] {dummy}); //If we get the shape wrong we crash here.
72+
}
73+
}

0 commit comments

Comments
 (0)