Skip to content

Commit fe1560e

Browse files
committed
wip. recurrent encode, lottery.
Signed-off-by: Robert Altena <[email protected]>
1 parent 55be9bb commit fe1560e

File tree

5 files changed

+50
-67
lines changed

5 files changed

+50
-67
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/encdec/CorpusIterator.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -50,9 +50,8 @@ public class CorpusIterator implements MultiDataSetIterator {
5050
private int currentMacroBatch = 0;
5151
private int dictSize;
5252
private int rowSize;
53-
private MultiDataSetPreProcessor preProcessor;
5453

55-
public CorpusIterator(List<List<Double>> corpus, int batchSize, int batchesPerMacrobatch, int dictSize, int rowSize) {
54+
CorpusIterator(List<List<Double>> corpus, int batchSize, int batchesPerMacrobatch, int dictSize, int rowSize) {
5655
this.corpus = corpus;
5756
this.batchSize = batchSize;
5857
this.batchesPerMacrobatch = batchesPerMacrobatch;
@@ -100,8 +99,8 @@ public MultiDataSet next(int num) {
10099
Nd4j.ones(rowPred.size()));
101100
// prediction (output) and decode ARE one-hots though, I couldn't add an embedding layer on top of the decoder and I'm not sure
102101
// it's a good idea either
103-
double predOneHot[][] = new double[dictSize][rowPred.size()];
104-
double decodeOneHot[][] = new double[dictSize][rowPred.size()];
102+
double[][] predOneHot = new double[dictSize][rowPred.size()];
103+
double[][] decodeOneHot = new double[dictSize][rowPred.size()];
105104
decodeOneHot[2][0] = 1; // <go> token
106105
int predIdx = 0;
107106
for (Double pred : rowPred) {
@@ -149,24 +148,23 @@ public int batch() {
149148
return currentBatch;
150149
}
151150

152-
public int totalBatches() {
151+
int totalBatches() {
153152
return totalBatches;
154153
}
155154

156-
public void setCurrentBatch(int currentBatch) {
155+
void setCurrentBatch(int currentBatch) {
157156
this.currentBatch = currentBatch;
158157
currentMacroBatch = getMacroBatchByCurrentBatch();
159158
}
160159

161-
public boolean hasNextMacrobatch() {
160+
boolean hasNextMacrobatch() {
162161
return getMacroBatchByCurrentBatch() < totalMacroBatches && currentMacroBatch < totalMacroBatches;
163162
}
164163

165-
public void nextMacroBatch() {
164+
void nextMacroBatch() {
166165
++currentMacroBatch;
167166
}
168167

169168
public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
170-
this.preProcessor = preProcessor;
171169
}
172170
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/encdec/CorpusProcessor.java

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -21,19 +21,19 @@
2121
import java.util.*;
2222

2323
public class CorpusProcessor {
24-
public static final String SPECIALS = "!\"#$;%^:?*()[]{}<>«»,.–—=+…";
24+
static final String SPECIALS = "!\"#$;%^:?*()[]{}<>«»,.–—=+…";
2525
private Set<String> dictSet = new HashSet<>();
2626
private Map<String, Double> freq = new HashMap<>();
2727
private Map<String, Double> dict = new HashMap<>();
2828
private boolean countFreq;
2929
private InputStream is;
3030
private int rowSize;
3131

32-
public CorpusProcessor(String filename, int rowSize, boolean countFreq) throws FileNotFoundException {
32+
CorpusProcessor(String filename, int rowSize, boolean countFreq) throws FileNotFoundException {
3333
this(new FileInputStream(filename), rowSize, countFreq);
3434
}
3535

36-
public CorpusProcessor(InputStream is, int rowSize, boolean countFreq) {
36+
CorpusProcessor(InputStream is, int rowSize, boolean countFreq) {
3737
this.is = is;
3838
this.rowSize = rowSize;
3939
this.countFreq = countFreq;
@@ -43,33 +43,33 @@ public void start() throws IOException {
4343
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
4444
String line;
4545
String lastName = "";
46-
String lastLine = "";
46+
StringBuilder lastLine = new StringBuilder();
4747
while ((line = br.readLine()) != null) {
4848
String[] lineSplit = line.toLowerCase().split(" \\+\\+\\+\\$\\+\\+\\+ ", 5);
4949
if (lineSplit.length > 4) {
5050
// join consecutive lines from the same speaker
5151
if (lineSplit[1].equals(lastName)) {
52-
if (!lastLine.isEmpty()) {
52+
if (lastLine.length() > 0) {
5353
// if the previous line doesn't end with a special symbol, append a comma and the current line
5454
if (!SPECIALS.contains(lastLine.substring(lastLine.length() - 1))) {
55-
lastLine += ",";
55+
lastLine.append(",");
5656
}
57-
lastLine += " " + lineSplit[4];
57+
lastLine.append(" ").append(lineSplit[4]);
5858
} else {
59-
lastLine = lineSplit[4];
59+
lastLine = new StringBuilder(lineSplit[4]);
6060
}
6161
} else {
62-
if (lastLine.isEmpty()) {
63-
lastLine = lineSplit[4];
62+
if (lastLine.length() == 0) {
63+
lastLine = new StringBuilder(lineSplit[4]);
6464
} else {
65-
processLine(lastLine);
66-
lastLine = lineSplit[4];
65+
processLine(lastLine.toString());
66+
lastLine = new StringBuilder(lineSplit[4]);
6767
}
6868
lastName = lineSplit[1];
6969
}
7070
}
7171
}
72-
processLine(lastLine);
72+
processLine(lastLine.toString());
7373
}
7474
}
7575

@@ -78,7 +78,7 @@ protected void processLine(String lastLine) {
7878
}
7979

8080
// here we not only split the words but also store punctuation marks
81-
protected void tokenizeLine(String lastLine, Collection<String> resultCollection, boolean addSpecials) {
81+
void tokenizeLine(String lastLine, Collection<String> resultCollection, boolean addSpecials) {
8282
String[] words = lastLine.split("[ \t]");
8383
for (String word : words) {
8484
if (!word.isEmpty()) {
@@ -122,15 +122,11 @@ private void addWord(Collection<String> coll, String word) {
122122
}
123123
}
124124

125-
public Set<String> getDictSet() {
126-
return dictSet;
127-
}
128-
129-
public Map<String, Double> getFreq() {
125+
Map<String, Double> getFreq() {
130126
return freq;
131127
}
132128

133-
public void setDict(Map<String, Double> dict) {
129+
void setDict(Map<String, Double> dict) {
134130
this.dict = dict;
135131
}
136132

@@ -142,7 +138,7 @@ public void setDict(Map<String, Double> dict) {
142138
* sequence of words
143139
* @return list of indices.
144140
*/
145-
protected final List<Double> wordsToIndexes(final Iterable<String> words) {
141+
final List<Double> wordsToIndexes(final Iterable<String> words) {
146142
int i = rowSize;
147143
final List<Double> wordIdxs = new LinkedList<>();
148144
for (final String word : words) {

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/encdec/EncoderDecoderLSTM.java

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -168,8 +168,6 @@ public class EncoderDecoderLSTM {
168168
*/
169169
private final Map<Double, String> revDict = new HashMap<>();
170170

171-
private final String CHARS = "-\\/_&" + CorpusProcessor.SPECIALS;
172-
173171
/**
174172
* The contents of the corpus. This is a list of sentences (each word of the
175173
* sentence is denoted by a {@link java.lang.Double}).
@@ -189,7 +187,6 @@ public class EncoderDecoderLSTM {
189187
// dictionary) are replaced with <unk> token
190188
private static final int TBPTT_SIZE = 25;
191189
private static final double LEARNING_RATE = 1e-1;
192-
private static final double RMS_DECAY = 0.95;
193190
private static final int ROW_SIZE = 40; // maximum line length in tokens
194191

195192
/**
@@ -207,10 +204,10 @@ public class EncoderDecoderLSTM {
207204
private ComputationGraph net;
208205

209206
public static void main(String[] args) throws IOException {
210-
new EncoderDecoderLSTM().run(args);
207+
new EncoderDecoderLSTM().run();
211208
}
212209

213-
private void run(String[] args) throws IOException {
210+
private void run() throws IOException {
214211
Nd4j.getMemoryManager().setAutoGcWindow(GC_WINDOW);
215212

216213
createDictionary();
@@ -227,7 +224,7 @@ private void run(String[] args) throws IOException {
227224
if (input.toLowerCase().equals("d")) {
228225
startDialog(scanner);
229226
} else {
230-
offset = Integer.valueOf(input);
227+
offset = Integer.parseInt(input);
231228
test();
232229
}
233230
}
@@ -326,6 +323,7 @@ private void train(File networkFile, int offset) throws IOException {
326323
}
327324
}
328325

326+
@SuppressWarnings("InfiniteLoopStatement")
329327
private void startDialog(Scanner scanner) throws IOException {
330328
System.out.println("Dialog started.");
331329
while (true) {
@@ -385,10 +383,10 @@ private void test() {
385383
private void output(List<Double> rowIn, boolean printUnknowns) {
386384
net.rnnClearPreviousState();
387385
Collections.reverse(rowIn);
388-
INDArray in = Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0])), new int[] { 1, 1, rowIn.size() });
386+
INDArray in = Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0])), 1, 1, rowIn.size());
389387
double[] decodeArr = new double[dict.size()];
390388
decodeArr[2] = 1;
391-
INDArray decode = Nd4j.create(decodeArr, new int[] { 1, dict.size(), 1 });
389+
INDArray decode = Nd4j.create(decodeArr, 1, dict.size(), 1);
392390
net.feedForward(new INDArray[] { in, decode }, false, false);
393391
org.deeplearning4j.nn.layers.recurrent.LSTM decoder = (org.deeplearning4j.nn.layers.recurrent.LSTM) net
394392
.getLayer("decoder");
@@ -419,19 +417,20 @@ private void output(List<Double> rowIn, boolean printUnknowns) {
419417
}
420418
double[] newDecodeArr = new double[dict.size()];
421419
newDecodeArr[idx] = 1;
422-
decode = Nd4j.create(newDecodeArr, new int[] { 1, dict.size(), 1 });
420+
decode = Nd4j.create(newDecodeArr, 1, dict.size(), 1);
423421
}
424422
System.out.println();
425423
}
426424

427-
private void createDictionary() throws IOException, FileNotFoundException {
425+
private void createDictionary() throws IOException {
428426
double idx = 3.0;
429427
dict.put("<unk>", 0.0);
430428
revDict.put(0.0, "<unk>");
431429
dict.put("<eos>", 1.0);
432430
revDict.put(1.0, "<eos>");
433431
dict.put("<go>", 2.0);
434432
revDict.put(2.0, "<go>");
433+
String CHARS = "-\\/_&" + CorpusProcessor.SPECIALS;
435434
for (char c : CHARS.toCharArray()) {
436435
if (!dict.containsKey(String.valueOf(c))) {
437436
dict.put(String.valueOf(c), idx);
@@ -443,7 +442,6 @@ private void createDictionary() throws IOException, FileNotFoundException {
443442
CorpusProcessor corpusProcessor = new CorpusProcessor(toTempPath(CORPUS_FILENAME), ROW_SIZE, true);
444443
corpusProcessor.start();
445444
Map<String, Double> freqs = corpusProcessor.getFreq();
446-
Set<String> dictSet = new TreeSet<>(); // the tokens order is preserved for TreeSet
447445
Map<Double, Set<String>> freqMap = new TreeMap<>(new Comparator<Double>() {
448446

449447
@Override
@@ -452,15 +450,13 @@ public int compare(Double o1, Double o2) {
452450
}
453451
}); // tokens with the same frequency share a key; the order is reversed so the most frequent tokens come first
454452
for (Entry<String, Double> entry : freqs.entrySet()) {
455-
Set<String> set = freqMap.get(entry.getValue());
456-
if (set == null) {
457-
set = new TreeSet<>(); // tokens of the same frequency would be sorted alphabetically
458-
freqMap.put(entry.getValue(), set);
459-
}
453+
Set<String> set = freqMap.computeIfAbsent(entry.getValue(), k -> new TreeSet<>());
454+
// tokens of the same frequency would be sorted alphabetically
460455
set.add(entry.getKey());
461456
}
462457
int cnt = 0;
463-
dictSet.addAll(dict.keySet());
458+
// a TreeSet keeps the tokens in a stable (sorted) order
459+
Set<String> dictSet = new TreeSet<>(dict.keySet());
464460
// get most frequent tokens and put them to dictSet
465461
for (Entry<Double, Set<String>> entry : freqMap.entrySet()) {
466462
for (String val : entry.getValue()) {

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processlottery/BaseDataSetReader.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -19,7 +19,7 @@
1919
import org.nd4j.linalg.dataset.DataSet;
2020

2121
import java.io.Serializable;
22-
import java.nio.charset.Charset;
22+
import java.nio.charset.StandardCharsets;
2323
import java.nio.file.Files;
2424
import java.nio.file.Path;
2525
import java.util.Iterator;
@@ -32,13 +32,13 @@ public abstract class BaseDataSetReader implements Serializable {
3232

3333
protected Iterator<String> iter;
3434
protected Path filePath;
35-
protected int totalExamples;
36-
protected int currentCursor;
35+
private int totalExamples;
36+
int currentCursor;
3737

38-
public void doInitialize(){
38+
void doInitialize(){
3939
List<String> dataLines;
4040
try {
41-
dataLines = Files.readAllLines(filePath, Charset.forName("UTF-8"));
41+
dataLines = Files.readAllLines(filePath, StandardCharsets.UTF_8);
4242
} catch (Exception e) {
4343
throw new RuntimeException("loading data failed");
4444
}
@@ -62,11 +62,7 @@ public List<String> getLabels() {
6262
public void reset() {
6363
doInitialize();
6464
}
65-
public int totalExamples() {
65+
int totalExamples() {
6666
return totalExamples;
6767
}
68-
public int cursor() {
69-
return currentCursor;
70-
}
71-
7268
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processlottery/LotteryCharacterSequenceDataSetReader.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -29,7 +29,7 @@
2929
*/
3030
public class LotteryCharacterSequenceDataSetReader extends BaseDataSetReader {
3131

32-
public LotteryCharacterSequenceDataSetReader(File file) {
32+
LotteryCharacterSequenceDataSetReader(File file) {
3333
filePath = file.toPath();
3434
doInitialize();
3535
}
@@ -39,9 +39,6 @@ public DataSet next(int num) {
3939
INDArray features = Nd4j.create(new int[]{num, 10, 16}, 'f');
4040
INDArray labels = Nd4j.create(new int[]{num, 10, 16}, 'f');
4141

42-
43-
INDArray featuresMask = null;
44-
INDArray labelsMask = null;
4542
for (int i =0; i < num && iter.hasNext(); i ++) {
4643
String featureStr = iter.next();
4744
currentCursor ++;
@@ -54,7 +51,7 @@ public DataSet next(int num) {
5451
labels.putScalar(new int[]{i, label, j}, 1.0);
5552
}
5653
}
57-
return new DataSet(features, labels, featuresMask, labelsMask);
54+
return new DataSet(features, labels, null, null);
5855
}
5956

6057
}

0 commit comments

Comments
 (0)