1
- /** *****************************************************************************
1
+ /* *****************************************************************************
2
2
* Copyright (c) 2015-2019 Skymind, Inc.
3
3
*
4
4
* This program and the accompanying materials are made available under the
20
20
import org .apache .commons .io .FilenameUtils ;
21
21
import org .deeplearning4j .examples .recurrent .word2vecsentiment .Word2VecSentimentRNN ;
22
22
import org .deeplearning4j .iterator .CnnSentenceDataSetIterator ;
23
+ import org .deeplearning4j .iterator .CnnSentenceDataSetIterator .Format ;
23
24
import org .deeplearning4j .iterator .LabeledSentenceProvider ;
24
25
import org .deeplearning4j .iterator .provider .FileLabeledSentenceProvider ;
25
26
import org .deeplearning4j .models .embeddings .loader .WordVectorSerializer ;
47
48
import org .nd4j .linalg .lossfunctions .LossFunctions ;
48
49
49
50
import java .io .File ;
51
+ import java .nio .charset .Charset ;
50
52
import java .util .*;
51
53
52
54
/**
58
60
*/
59
61
public class CnnSentenceClassificationExample {
60
62
61
- /** Data URL for downloading */
62
- public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" ;
63
63
/** Location to save and extract the training/testing data */
64
64
public static final String DATA_PATH = FilenameUtils .concat (System .getProperty ("java.io.tmpdir" ), "dl4j_w2vSentiment/" );
65
65
/** Location (local file system) for the Google News vectors. Set this manually. */
66
- public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
66
+ private static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
67
67
68
68
public static void main (String [] args ) throws Exception {
69
+ //noinspection ConstantConditions
69
70
if (WORD_VECTORS_PATH .startsWith ("/PATH/TO/YOUR/VECTORS/" )){
70
71
throw new RuntimeException ("Please set the WORD_VECTORS_PATH before running this example" );
71
72
}
@@ -149,7 +150,7 @@ public static void main(String[] args) throws Exception {
149
150
150
151
//After training: load a single sentence and generate a prediction
151
152
String pathFirstNegativeFile = FilenameUtils .concat (DATA_PATH , "aclImdb/test/neg/0_2.txt" );
152
- String contentsFirstNegative = FileUtils .readFileToString (new File (pathFirstNegativeFile ));
153
+ String contentsFirstNegative = FileUtils .readFileToString (new File (pathFirstNegativeFile ), ( Charset ) null );
153
154
INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator )testIter ).loadSingleSentence (contentsFirstNegative );
154
155
155
156
INDArray predictionsFirstNegative = net .outputSingle (featuresFirstNegative );
@@ -172,12 +173,12 @@ private static DataSetIterator getDataSetIterator(boolean isTraining, WordVector
172
173
File fileNegative = new File (negativeBaseDir );
173
174
174
175
Map <String ,List <File >> reviewFilesMap = new HashMap <>();
175
- reviewFilesMap .put ("Positive" , Arrays .asList (filePositive .listFiles ()));
176
- reviewFilesMap .put ("Negative" , Arrays .asList (fileNegative .listFiles ()));
176
+ reviewFilesMap .put ("Positive" , Arrays .asList (Objects . requireNonNull ( filePositive .listFiles () )));
177
+ reviewFilesMap .put ("Negative" , Arrays .asList (Objects . requireNonNull ( fileNegative .listFiles () )));
177
178
178
179
LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider (reviewFilesMap , rng );
179
180
180
- return new CnnSentenceDataSetIterator .Builder ()
181
+ return new CnnSentenceDataSetIterator .Builder (Format . CNN2D )
181
182
.sentenceProvider (sentenceProvider )
182
183
.wordVectors (wordVectors )
183
184
.minibatchSize (minibatchSize )
0 commit comments