1
- /** *****************************************************************************
1
+ /* *****************************************************************************
2
2
* Copyright (c) 2015-2019 Skymind, Inc.
3
3
*
4
4
* This program and the accompanying materials are made available under the
37
37
* Sequences generated during test are never before seen by the net
38
38
* The random number generator seed is used for repeatability so that each reset of the iterator gives the same data in the same order.
39
39
*/
40
-
41
40
public class CustomSequenceIterator implements MultiDataSetIterator {
42
41
43
- private MultiDataSetPreProcessor preProcessor ;
44
42
private Random randnumG ;
45
43
private final int seed ;
46
44
private final int batchSize ;
47
45
private final int totalBatches ;
48
46
49
47
private static final int numDigits = AdditionRNN .NUM_DIGITS ;
50
- public static final int SEQ_VECTOR_DIM = AdditionRNN .FEATURE_VEC_SIZE ;
51
- public static final Map <String , Integer > oneHotMap = new HashMap <String , Integer >();
52
- public static final String [] oneHotOrder = new String [SEQ_VECTOR_DIM ];
48
+ private static final int SEQ_VECTOR_DIM = AdditionRNN .FEATURE_VEC_SIZE ;
49
+ private static final Map <String , Integer > oneHotMap = new HashMap <>();
50
+ private static final String [] oneHotOrder = new String [SEQ_VECTOR_DIM ];
53
51
54
- private Set <String > seenSequences = new HashSet <String >();
52
+ private Set <String > seenSequences = new HashSet <>();
55
53
private boolean toTestSet = false ;
56
54
private int currentBatch = 0 ;
57
55
58
- public CustomSequenceIterator (int seed , int batchSize , int totalBatches ) {
56
+ CustomSequenceIterator (int seed , int batchSize , int totalBatches ) {
59
57
60
58
this .seed = seed ;
61
59
this .randnumG = new Random (seed );
@@ -66,7 +64,7 @@ public CustomSequenceIterator(int seed, int batchSize, int totalBatches) {
66
64
oneHotEncoding ();
67
65
}
68
66
69
- public MultiDataSet generateTest (int testSize ) {
67
+ MultiDataSet generateTest (int testSize ) {
70
68
toTestSet = true ;
71
69
MultiDataSet testData = next (testSize );
72
70
reset ();
@@ -86,7 +84,7 @@ public MultiDataSet next(int sampleSize) {
86
84
while (true ) {
87
85
num1 = randnumG .nextInt ((int ) Math .pow (10 , numDigits ));
88
86
num2 = randnumG .nextInt ((int ) Math .pow (10 , numDigits ));
89
- String forSum = String . valueOf ( num1 ) + "+" + String . valueOf ( num2 ) ;
87
+ String forSum = num1 + "+" + num2 ;
90
88
if (seenSequences .add (forSum )) {
91
89
break ;
92
90
}
@@ -126,7 +124,7 @@ public MultiDataSet next(int sampleSize) {
126
124
public void reset () {
127
125
currentBatch = 0 ;
128
126
toTestSet = false ;
129
- seenSequences = new HashSet <String >();
127
+ seenSequences = new HashSet <>();
130
128
randnumG = new Random (seed );
131
129
}
132
130
@@ -166,17 +164,17 @@ public MultiDataSetPreProcessor getPreProcessor() {
166
164
Note that the string is padded to the correct length and reversed
167
165
Eg. num1 = 7, num 2 = 13 will return {"3","1","+","7"," "}
168
166
*/
169
- public String [] prepToString (int num1 , int num2 ) {
167
+ private String [] prepToString (int num1 , int num2 ) {
170
168
171
169
String [] encoded = new String [numDigits * 2 + 1 ];
172
- String num1S = String .valueOf (num1 );
173
- String num2S = String .valueOf (num2 );
170
+ StringBuilder num1S = new StringBuilder ( String .valueOf (num1 ) );
171
+ StringBuilder num2S = new StringBuilder ( String .valueOf (num2 ) );
174
172
//padding
175
173
while (num1S .length () < numDigits ) {
176
- num1S = " " + num1S ;
174
+ num1S . insert ( 0 , " " ) ;
177
175
}
178
176
while (num2S .length () < numDigits ) {
179
- num2S = " " + num2S ;
177
+ num2S . insert ( 0 , " " ) ;
180
178
}
181
179
182
180
String sumString = num1S + "+" + num2S ;
@@ -199,7 +197,7 @@ Given a number, return a string array which represents the decoder input (or out
199
197
if !goFirst will return {"3","1"," ","eos"}
200
198
201
199
*/
202
- public String [] prepToString (int sum , boolean goFirst ) {
200
+ private String [] prepToString (int sum , boolean goFirst ) {
203
201
int start , end ;
204
202
String [] decoded = new String [numDigits + 1 + 1 ];
205
203
if (goFirst ) {
@@ -244,24 +242,21 @@ private static INDArray mapToOneHot(String[] toEncode) {
244
242
return ret ;
245
243
}
246
244
247
- public static String mapToString (INDArray encodeSeq , INDArray decodeSeq ) {
248
- return mapToString (encodeSeq ,decodeSeq ," --> " );
249
- }
250
- public static String mapToString (INDArray encodeSeq , INDArray decodeSeq , String sep ) {
251
- String ret = "" ;
245
+ static String mapToString (INDArray encodeSeq , INDArray decodeSeq ) {
246
+ StringBuilder ret = new StringBuilder ();
252
247
String [] encodeSeqS = oneHotDecode (encodeSeq );
253
248
String [] decodeSeqS = oneHotDecode (decodeSeq );
254
249
for (int i =0 ; i <encodeSeqS .length ;i ++) {
255
- ret += "\t " + encodeSeqS [i ] + sep + decodeSeqS [i ] + "\n " ;
250
+ ret . append ( "\t " ). append ( encodeSeqS [i ]). append ( " + " ). append ( decodeSeqS [i ]). append ( "\n " ) ;
256
251
}
257
- return ret ;
252
+ return ret . toString () ;
258
253
}
259
254
260
255
/*
261
256
Helper method that takes in a one hot encoded INDArray and returns an interpreted array of strings
262
257
toInterpret size batchSize x one_hot_vector_size(14) x time_steps
263
258
*/
264
- public static String [] oneHotDecode (INDArray toInterpret ) {
259
+ static String [] oneHotDecode (INDArray toInterpret ) {
265
260
266
261
String [] decodedString = new String [(int )toInterpret .size (0 )];
267
262
INDArray oneHotIndices = Nd4j .argMax (toInterpret , 1 ); //drops a dimension, so now a two dim array of shape batchSize x time_steps
@@ -273,15 +268,15 @@ public static String[] oneHotDecode(INDArray toInterpret) {
273
268
}
274
269
275
270
private static String mapFromOneHot (int [] toMap ) {
276
- String ret = "" ;
277
- for (int i = 0 ; i < toMap . length ; i ++ ) {
278
- ret += oneHotOrder [toMap [ i ]] ;
271
+ StringBuilder ret = new StringBuilder () ;
272
+ for (int value : toMap ) {
273
+ ret . append ( oneHotOrder [value ]) ;
279
274
}
280
275
//encoder sequence, needs to be reversed
281
276
if (toMap .length > numDigits + 1 + 1 ) {
282
- return new StringBuilder (ret ).reverse ().toString ();
277
+ return new StringBuilder (ret . toString () ).reverse ().toString ();
283
278
}
284
- return ret ;
279
+ return ret . toString () ;
285
280
}
286
281
287
282
/*
@@ -308,6 +303,5 @@ private static void oneHotEncoding() {
308
303
}
309
304
310
305
public void setPreProcessor (MultiDataSetPreProcessor preProcessor ) {
311
- this .preProcessor = preProcessor ;
312
306
}
313
307
}
0 commit comments