Skip to content

Commit 7e6ae4b

Browse files
committed
update CnnSentenceClassificationExample.
Signed-off-by: Robert Altena <[email protected]>
1 parent c96733c commit 7e6ae4b

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/sentenceclassification/CnnSentenceClassificationExample.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -20,6 +20,7 @@
2020
import org.apache.commons.io.FilenameUtils;
2121
import org.deeplearning4j.examples.recurrent.word2vecsentiment.Word2VecSentimentRNN;
2222
import org.deeplearning4j.iterator.CnnSentenceDataSetIterator;
23+
import org.deeplearning4j.iterator.CnnSentenceDataSetIterator.Format;
2324
import org.deeplearning4j.iterator.LabeledSentenceProvider;
2425
import org.deeplearning4j.iterator.provider.FileLabeledSentenceProvider;
2526
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
@@ -47,6 +48,7 @@
4748
import org.nd4j.linalg.lossfunctions.LossFunctions;
4849

4950
import java.io.File;
51+
import java.nio.charset.Charset;
5052
import java.util.*;
5153

5254
/**
@@ -58,14 +60,13 @@
5860
*/
5961
public class CnnSentenceClassificationExample {
6062

61-
/** Data URL for downloading */
62-
public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
6363
/** Location to save and extract the training/testing data */
6464
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
6565
/** Location (local file system) for the Google News vectors. Set this manually. */
66-
public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
66+
private static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
6767

6868
public static void main(String[] args) throws Exception {
69+
//noinspection ConstantConditions
6970
if(WORD_VECTORS_PATH.startsWith("/PATH/TO/YOUR/VECTORS/")){
7071
throw new RuntimeException("Please set the WORD_VECTORS_PATH before running this example");
7172
}
@@ -149,7 +150,7 @@ public static void main(String[] args) throws Exception {
149150

150151
//After training: load a single sentence and generate a prediction
151152
String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
152-
String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile));
153+
String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile), (Charset) null);
153154
INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator)testIter).loadSingleSentence(contentsFirstNegative);
154155

155156
INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
@@ -172,12 +173,12 @@ private static DataSetIterator getDataSetIterator(boolean isTraining, WordVector
172173
File fileNegative = new File(negativeBaseDir);
173174

174175
Map<String,List<File>> reviewFilesMap = new HashMap<>();
175-
reviewFilesMap.put("Positive", Arrays.asList(filePositive.listFiles()));
176-
reviewFilesMap.put("Negative", Arrays.asList(fileNegative.listFiles()));
176+
reviewFilesMap.put("Positive", Arrays.asList(Objects.requireNonNull(filePositive.listFiles())));
177+
reviewFilesMap.put("Negative", Arrays.asList(Objects.requireNonNull(fileNegative.listFiles())));
177178

178179
LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, rng);
179180

180-
return new CnnSentenceDataSetIterator.Builder()
181+
return new CnnSentenceDataSetIterator.Builder(Format.CNN2D)
181182
.sentenceProvider(sentenceProvider)
182183
.wordVectors(wordVectors)
183184
.minibatchSize(minibatchSize)

0 commit comments

Comments
 (0)