Skip to content

Commit 55be9bb

Browse files
committed
recurrent basic and character.
Signed-off-by: Robert Altena <[email protected]>
1 parent 9f891d8 commit 55be9bb

File tree

4 files changed

+26
-22
lines changed

4 files changed

+26
-22
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/basic/BasicRNNExample.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -35,7 +35,6 @@
3535
import java.util.ArrayList;
3636
import java.util.LinkedHashSet;
3737
import java.util.List;
38-
import java.util.Random;
3938

4039
/**
4140
* This example trains a RNN. When trained we only have to put the first
@@ -55,7 +54,6 @@ public class BasicRNNExample {
5554
// RNN dimensions
5655
private static final int HIDDEN_LAYER_WIDTH = 50;
5756
private static final int HIDDEN_LAYER_CONT = 2;
58-
private static final Random r = new Random(7894);
5957

6058
public static void main(String[] args) {
6159

@@ -123,7 +121,7 @@ public static void main(String[] args) {
123121
DataSet trainingData = new DataSet(input, labels);
124122

125123
// some epochs
126-
for (int epoch = 0; epoch < 100; epoch++) {
124+
for (int epoch = 0; epoch < 1000; epoch++) {
127125

128126
System.out.println("Epoch " + epoch);
129127

@@ -143,7 +141,7 @@ public static void main(String[] args) {
143141
INDArray output = net.rnnTimeStep(testInit);
144142

145143
// now the net should guess LEARNSTRING.length more characters
146-
for (char dummy : LEARNSTRING) {
144+
for (char ignored : LEARNSTRING) {
147145

148146
// first process the last output of the network to a concrete
149147
// neuron, the neuron with the highest output has the highest

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/CharacterIterator.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -62,8 +62,8 @@ public class CharacterIterator implements DataSetIterator {
6262
* @param rng Random number generator, for repeatability if required
6363
* @throws IOException If text file cannot be loaded
6464
*/
65-
public CharacterIterator(String textFilePath, Charset textFileEncoding, int miniBatchSize, int exampleLength,
66-
char[] validCharacters, Random rng) throws IOException {
65+
CharacterIterator(String textFilePath, Charset textFileEncoding, int miniBatchSize, int exampleLength,
66+
char[] validCharacters, Random rng) throws IOException {
6767
this(textFilePath,textFileEncoding,miniBatchSize,exampleLength,validCharacters,rng,null);
6868
}
6969
/**
@@ -130,7 +130,7 @@ public CharacterIterator(String textFilePath, Charset textFileEncoding, int mini
130130
}
131131

132132
/** A minimal character set, with a-z, A-Z, 0-9 and common punctuation etc */
133-
public static char[] getMinimalCharacterSet(){
133+
static char[] getMinimalCharacterSet(){
134134
List<Character> validChars = new LinkedList<>();
135135
for(char c='a'; c<='z'; c++) validChars.add(c);
136136
for(char c='A'; c<='Z'; c++) validChars.add(c);
@@ -144,7 +144,8 @@ public static char[] getMinimalCharacterSet(){
144144
}
145145

146146
/** As per getMinimalCharacterSet(), but with a few extra characters */
147-
public static char[] getDefaultCharacterSet(){
147+
@SuppressWarnings("unused")
148+
public static char[] getDefaultCharacterSet(){
148149
List<Character> validChars = new LinkedList<>();
149150
for(char c : getMinimalCharacterSet() ) validChars.add(c);
150151
char[] additionalChars = {'@', '#', '$', '%', '^', '*', '{', '}', '[', ']', '/', '+', '_',
@@ -205,7 +206,7 @@ public DataSet next(int num) {
205206
return new DataSet(input,labels);
206207
}
207208

208-
public int totalExamples() {
209+
private int totalExamples() {
209210
return (fileCharacters.length-1) / miniBatchSize - 2;
210211
}
211212

@@ -244,11 +245,13 @@ public int batch() {
244245
return miniBatchSize;
245246
}
246247

247-
public int cursor() {
248+
@SuppressWarnings("unused")
249+
public int cursor() {
248250
return totalExamples() - exampleStartOffsets.size();
249251
}
250252

251-
public int numExamples() {
253+
@SuppressWarnings("unused")
254+
public int numExamples() {
252255
return totalExamples();
253256
}
254257

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/CompGraphLSTMExample.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -46,9 +46,11 @@
4646
*
4747
* @author Alex Black
4848
*/
49+
@SuppressWarnings("DuplicatedCode")
4950
public class CompGraphLSTMExample {
5051

51-
public static void main( String[] args ) throws Exception {
52+
@SuppressWarnings("ConstantConditions")
53+
public static void main(String[] args ) throws Exception {
5254
int lstmLayerSize = 200; //Number of units in each LSTM layer
5355
int miniBatchSize = 32; //Size of mini batch to use when training
5456
int exampleLength = 1000; //Length of each training example sequence to use. This could certainly be increased

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/LSTMCharModellingExample.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -32,13 +32,12 @@
3232
import org.nd4j.linalg.indexing.BooleanIndexing;
3333
import org.nd4j.linalg.indexing.conditions.Conditions;
3434
import org.nd4j.linalg.learning.config.Adam;
35-
import org.nd4j.linalg.learning.config.Nadam;
3635
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
3736

3837
import java.io.File;
3938
import java.io.IOException;
4039
import java.net.URL;
41-
import java.nio.charset.Charset;
40+
import java.nio.charset.StandardCharsets;
4241
import java.util.Random;
4342

4443
/**LSTM Character modelling example
@@ -57,8 +56,10 @@
5756
http://deeplearning4j.org/lstm
5857
http://deeplearning4j.org/recurrentnetwork
5958
*/
59+
@SuppressWarnings("DuplicatedCode")
6060
public class LSTMCharModellingExample {
61-
public static void main( String[] args ) throws Exception {
61+
@SuppressWarnings("ConstantConditions")
62+
public static void main(String[] args ) throws Exception {
6263
int lstmLayerSize = 200; //Number of units in each LSTM layer
6364
int miniBatchSize = 32; //Size of mini batch to use when training
6465
int exampleLength = 1000; //Length of each training example sequence to use. This could certainly be increased
@@ -130,7 +131,7 @@ public static void main( String[] args ) throws Exception {
130131
* @param miniBatchSize Number of text segments in each training mini-batch
131132
* @param sequenceLength Number of characters in each text segment.
132133
*/
133-
public static CharacterIterator getShakespeareIterator(int miniBatchSize, int sequenceLength) throws Exception{
134+
static CharacterIterator getShakespeareIterator(int miniBatchSize, int sequenceLength) throws Exception{
134135
//The Complete Works of William Shakespeare
135136
//5.3MB file in UTF-8 Encoding, ~5.4 million characters
136137
//https://www.gutenberg.org/ebooks/100
@@ -148,7 +149,7 @@ public static CharacterIterator getShakespeareIterator(int miniBatchSize, int se
148149
if(!f.exists()) throw new IOException("File does not exist: " + fileLocation); //Download problem?
149150

150151
char[] validCharacters = CharacterIterator.getMinimalCharacterSet(); //Which characters are allowed? Others will be removed
151-
return new CharacterIterator(fileLocation, Charset.forName("UTF-8"),
152+
return new CharacterIterator(fileLocation, StandardCharsets.UTF_8,
152153
miniBatchSize, sequenceLength, validCharacters, new Random(12345));
153154
}
154155

@@ -210,7 +211,7 @@ private static String[] sampleCharactersFromNetwork(String initialization, Multi
210211
* and return the generated class index.
211212
* @param distribution Probability distribution over classes. Must sum to 1.0
212213
*/
213-
public static int sampleFromDistribution( double[] distribution, Random rng ){
214+
static int sampleFromDistribution(double[] distribution, Random rng){
214215
double d = 0.0;
215216
double sum = 0.0;
216217
for( int t=0; t<10; t++ ) {

0 commit comments

Comments
 (0)