Skip to content

Commit bda1993

Browse files
committed
seq2seseq2seqq
Signed-off-by: Robert Altena <[email protected]>
1 parent a66ad14 commit bda1993

File tree

3 files changed

+36
-44
lines changed

3 files changed

+36
-44
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/seq2seq/AdditionRNN.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -27,7 +27,6 @@
2727
import org.deeplearning4j.nn.weights.WeightInit;
2828
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
2929
import org.nd4j.linalg.activations.Activation;
30-
import org.nd4j.linalg.api.buffer.DataBuffer;
3130
import org.nd4j.linalg.api.buffer.DataType;
3231
import org.nd4j.linalg.api.ndarray.INDArray;
3332
import org.nd4j.linalg.dataset.api.MultiDataSet;
@@ -79,27 +78,27 @@ public class AdditionRNN {
7978
To try out addition for numbers with different number of digits simply change "NUM_DIGITS"
8079
*/
8180

82-
public static final int NUM_DIGITS =2;
81+
static final int NUM_DIGITS =2;
8382
//Random number generator seed, for reproducability
8483
public static final int seed = 1234;
8584

8685
//Tweak these to tune the dataset size = batchSize * totalBatches
8786
public static int batchSize = 10;
88-
public static int totalBatches = 500;
8987
public static int nEpochs = 10;
9088

9189
//Tweak the number of hidden nodes
92-
public static final int numHiddenNodes = 128;
90+
private static final int numHiddenNodes = 128;
9391

9492
//This is the size of the one hot vector
95-
public static final int FEATURE_VEC_SIZE = 14;
93+
static final int FEATURE_VEC_SIZE = 14;
9694

97-
public static void main(String[] args) throws Exception {
95+
public static void main(String[] args) {
9896

9997
//DataType is set to double for higher precision
10098
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
10199

102100
//This is a custom iterator that returns MultiDataSets on each call of next - More details in comments in the class
101+
int totalBatches = 500;
103102
CustomSequenceIterator iterator = new CustomSequenceIterator(seed, batchSize, totalBatches);
104103

105104
ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/seq2seq/CustomSequenceIterator.java

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -37,25 +37,23 @@
3737
* Sequences generated during test are never before seen by the net
3838
* The random number generator seed is used for repeatability so that each reset of the iterator gives the same data in the same order.
3939
*/
40-
4140
public class CustomSequenceIterator implements MultiDataSetIterator {
4241

43-
private MultiDataSetPreProcessor preProcessor;
4442
private Random randnumG;
4543
private final int seed;
4644
private final int batchSize;
4745
private final int totalBatches;
4846

4947
private static final int numDigits = AdditionRNN.NUM_DIGITS;
50-
public static final int SEQ_VECTOR_DIM = AdditionRNN.FEATURE_VEC_SIZE;
51-
public static final Map<String, Integer> oneHotMap = new HashMap<String, Integer>();
52-
public static final String[] oneHotOrder = new String[SEQ_VECTOR_DIM];
48+
private static final int SEQ_VECTOR_DIM = AdditionRNN.FEATURE_VEC_SIZE;
49+
private static final Map<String, Integer> oneHotMap = new HashMap<>();
50+
private static final String[] oneHotOrder = new String[SEQ_VECTOR_DIM];
5351

54-
private Set<String> seenSequences = new HashSet<String>();
52+
private Set<String> seenSequences = new HashSet<>();
5553
private boolean toTestSet = false;
5654
private int currentBatch = 0;
5755

58-
public CustomSequenceIterator(int seed, int batchSize, int totalBatches) {
56+
CustomSequenceIterator(int seed, int batchSize, int totalBatches) {
5957

6058
this.seed = seed;
6159
this.randnumG = new Random(seed);
@@ -66,7 +64,7 @@ public CustomSequenceIterator(int seed, int batchSize, int totalBatches) {
6664
oneHotEncoding();
6765
}
6866

69-
public MultiDataSet generateTest(int testSize) {
67+
MultiDataSet generateTest(int testSize) {
7068
toTestSet = true;
7169
MultiDataSet testData = next(testSize);
7270
reset();
@@ -86,7 +84,7 @@ public MultiDataSet next(int sampleSize) {
8684
while (true) {
8785
num1 = randnumG.nextInt((int) Math.pow(10, numDigits));
8886
num2 = randnumG.nextInt((int) Math.pow(10, numDigits));
89-
String forSum = String.valueOf(num1) + "+" + String.valueOf(num2);
87+
String forSum = num1 + "+" + num2;
9088
if (seenSequences.add(forSum)) {
9189
break;
9290
}
@@ -126,7 +124,7 @@ public MultiDataSet next(int sampleSize) {
126124
public void reset() {
127125
currentBatch = 0;
128126
toTestSet = false;
129-
seenSequences = new HashSet<String>();
127+
seenSequences = new HashSet<>();
130128
randnumG = new Random(seed);
131129
}
132130

@@ -166,17 +164,17 @@ public MultiDataSetPreProcessor getPreProcessor() {
166164
Note that the string is padded to the correct length and reversed
167165
Eg. num1 = 7, num 2 = 13 will return {"3","1","+","7"," "}
168166
*/
169-
public String[] prepToString(int num1, int num2) {
167+
private String[] prepToString(int num1, int num2) {
170168

171169
String[] encoded = new String[numDigits * 2 + 1];
172-
String num1S = String.valueOf(num1);
173-
String num2S = String.valueOf(num2);
170+
StringBuilder num1S = new StringBuilder(String.valueOf(num1));
171+
StringBuilder num2S = new StringBuilder(String.valueOf(num2));
174172
//padding
175173
while (num1S.length() < numDigits) {
176-
num1S = " " + num1S;
174+
num1S.insert(0, " ");
177175
}
178176
while (num2S.length() < numDigits) {
179-
num2S = " " + num2S;
177+
num2S.insert(0, " ");
180178
}
181179

182180
String sumString = num1S + "+" + num2S;
@@ -199,7 +197,7 @@ Given a number, return a string array which represents the decoder input (or out
199197
if !goFirst will return {"3","1"," ","eos"}
200198
201199
*/
202-
public String[] prepToString(int sum, boolean goFirst) {
200+
private String[] prepToString(int sum, boolean goFirst) {
203201
int start, end;
204202
String[] decoded = new String[numDigits + 1 + 1];
205203
if (goFirst) {
@@ -244,24 +242,21 @@ private static INDArray mapToOneHot(String[] toEncode) {
244242
return ret;
245243
}
246244

247-
public static String mapToString (INDArray encodeSeq, INDArray decodeSeq) {
248-
return mapToString(encodeSeq,decodeSeq," --> ");
249-
}
250-
public static String mapToString(INDArray encodeSeq, INDArray decodeSeq, String sep) {
251-
String ret = "";
245+
static String mapToString(INDArray encodeSeq, INDArray decodeSeq) {
246+
StringBuilder ret = new StringBuilder();
252247
String [] encodeSeqS = oneHotDecode(encodeSeq);
253248
String [] decodeSeqS = oneHotDecode(decodeSeq);
254249
for (int i=0; i<encodeSeqS.length;i++) {
255-
ret += "\t" + encodeSeqS[i] + sep +decodeSeqS[i] + "\n";
250+
ret.append("\t").append(encodeSeqS[i]).append(" + ").append(decodeSeqS[i]).append("\n");
256251
}
257-
return ret;
252+
return ret.toString();
258253
}
259254

260255
/*
261256
Helper method that takes in a one hot encoded INDArray and returns an interpreted array of strings
262257
toInterpret size batchSize x one_hot_vector_size(14) x time_steps
263258
*/
264-
public static String[] oneHotDecode(INDArray toInterpret) {
259+
static String[] oneHotDecode(INDArray toInterpret) {
265260

266261
String[] decodedString = new String[(int)toInterpret.size(0)];
267262
INDArray oneHotIndices = Nd4j.argMax(toInterpret, 1); //drops a dimension, so now a two dim array of shape batchSize x time_steps
@@ -273,15 +268,15 @@ public static String[] oneHotDecode(INDArray toInterpret) {
273268
}
274269

275270
private static String mapFromOneHot(int[] toMap) {
276-
String ret = "";
277-
for (int i = 0; i < toMap.length; i++) {
278-
ret += oneHotOrder[toMap[i]];
271+
StringBuilder ret = new StringBuilder();
272+
for (int value : toMap) {
273+
ret.append(oneHotOrder[value]);
279274
}
280275
//encoder sequence, needs to be reversed
281276
if (toMap.length > numDigits + 1 + 1) {
282-
return new StringBuilder(ret).reverse().toString();
277+
return new StringBuilder(ret.toString()).reverse().toString();
283278
}
284-
return ret;
279+
return ret.toString();
285280
}
286281

287282
/*
@@ -308,6 +303,5 @@ private static void oneHotEncoding() {
308303
}
309304

310305
public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
311-
this.preProcessor = preProcessor;
312306
}
313307
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/seq2seq/Seq2SeqPredicter.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -19,7 +19,6 @@
1919
import org.deeplearning4j.nn.graph.ComputationGraph;
2020
import org.nd4j.linalg.api.ndarray.INDArray;
2121
import org.nd4j.linalg.dataset.api.MultiDataSet;
22-
import org.nd4j.linalg.factory.Nd4j;
2322
import org.nd4j.linalg.indexing.NDArrayIndex;
2423

2524
/**
@@ -36,7 +35,7 @@ public class Seq2SeqPredicter {
3635
private ComputationGraph net;
3736
private INDArray decoderInputTemplate = null;
3837

39-
public Seq2SeqPredicter(ComputationGraph net) {
38+
Seq2SeqPredicter(ComputationGraph net) {
4039
this.net = net;
4140
}
4241

@@ -58,7 +57,7 @@ public INDArray output(MultiDataSet testSet) {
5857
public INDArray output(MultiDataSet testSet, boolean print) {
5958

6059
INDArray correctOutput = testSet.getLabels()[0];
61-
INDArray ret = Nd4j.zeros(correctOutput.shape());
60+
INDArray ret;
6261
decoderInputTemplate = testSet.getFeatures()[1].dup();
6362

6463
int currentStepThrough = 0;
@@ -68,7 +67,7 @@ public INDArray output(MultiDataSet testSet, boolean print) {
6867
if (print) {
6968
System.out.println("In time step "+currentStepThrough);
7069
System.out.println("\tEncoder input and Decoder input:");
71-
System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate, " + "));
70+
System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate));
7271

7372
}
7473
ret = stepOnce(testSet, currentStepThrough);
@@ -83,7 +82,7 @@ public INDArray output(MultiDataSet testSet, boolean print) {
8382
if (print) {
8483
System.out.println("Final time step "+currentStepThrough);
8584
System.out.println("\tEncoder input and Decoder input:");
86-
System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate, " + "));
85+
System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate));
8786
System.out.println("\tDecoder output:");
8887
System.out.println("\t"+String.join("\n\t",CustomSequenceIterator.oneHotDecode(ret)));
8988
}

0 commit comments

Comments
 (0)