35
35
import org .nd4j .linalg .api .ndarray .INDArray ;
36
36
import org .nd4j .linalg .factory .Nd4j ;
37
37
import org .nd4j .linalg .indexing .NDArrayIndex ;
38
+ import org .nd4j .linalg .io .ClassPathResource ;
38
39
import org .nd4j .linalg .learning .config .Adam ;
39
40
import org .nd4j .linalg .lossfunctions .LossFunctions ;
41
+ import org .nd4j .resources .Downloader ;
40
42
41
43
import java .io .File ;
44
+ import java .io .IOException ;
42
45
import java .net .URL ;
43
46
import java .nio .charset .Charset ;
47
+ import java .util .Scanner ;
44
48
45
- /**Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
49
+ /**
50
+ * Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
46
51
* This is done by combining Word2Vec vectors and a recurrent neural network model. Each word in a review is vectorized
47
52
* (using the Word2Vec model) and fed into a recurrent neural network.
48
53
* Training data is the "Large Movie Review Dataset" from http://ai.stanford.edu/~amaas/data/sentiment/
49
54
* This data set contains 25,000 training reviews + 25,000 testing reviews
50
- *
55
+ * <p>
51
56
* Process:
57
+ * 0. If path to the wordvectors is not set and a download not found previously in the default location you will be prompted if you want to download it.
52
58
* 1. Automatic on first run of example: Download data (movie reviews) + extract
53
- * 2. Load existing Word2Vec model (for example: Google News word vectors. You will have to download this MANUALLY )
59
+ * 2. Load existing Word2Vec model (for example: Google News word vectors.)
54
60
* 3. Load each each review. Convert words to vectors + reviews to sequences of vectors
55
61
* 4. Train network
56
- *
62
+ * <p>
57
63
* With the current configuration, gives approx. 83% accuracy after 1 epoch. Better performance may be possible with
58
64
* additional tuning.
59
65
*
60
- * NOTE / INSTRUCTIONS:
61
- * You will have to download the Google News word vector model manually. ~1.5GB
62
- * The Google News vector model available here: https://code.google.com/p/word2vec/
63
- * Download the GoogleNews-vectors-negative300.bin.gz file
64
- * Then: set the WORD_VECTORS_PATH field to point to this location.
65
- *
66
66
* @author Alex Black
67
67
*/
68
68
public class Word2VecSentimentRNN {
69
69
70
- /** Data URL for downloading */
70
+ /**
71
+ * Data URL for downloading
72
+ */
71
73
public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" ;
72
- /** Location to save and extract the training/testing data */
74
+ /**
75
+ * Location to save and extract the training/testing data
76
+ */
73
77
public static final String DATA_PATH = FilenameUtils .concat (System .getProperty ("java.io.tmpdir" ), "dl4j_w2vSentiment/" );
74
- /** Location (local file system) for the Google News vectors. Set this manually. */
75
- public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
76
-
78
+ /**
79
+ * Location (local file system) for the Google News vectors. Set this manually.
80
+ */
81
+ public static String wordVectorsPath = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
77
82
78
83
public static void main (String [] args ) throws Exception {
79
- if (WORD_VECTORS_PATH .startsWith ("/PATH/TO/YOUR/VECTORS/" )){
80
- throw new RuntimeException ("Please set the WORD_VECTORS_PATH before running this example" );
84
+ if (wordVectorsPath .startsWith ("/PATH/TO/YOUR/VECTORS/" )) {
85
+ System .out .println ("wordVectorsPath has not been set. Checking default location in ~/dl4j-examples-data for download..." );
86
+ checkDownloadW2VECModel ();
81
87
}
82
-
83
88
//Download and extract data
84
89
downloadData ();
85
90
@@ -93,23 +98,23 @@ public static void main(String[] args) throws Exception {
93
98
94
99
//Set up network configuration
95
100
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
96
- .seed (seed )
97
- .updater (new Adam (5e-3 ))
98
- .l2 (1e-5 )
99
- .weightInit (WeightInit .XAVIER )
100
- .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
101
- .list ()
102
- .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
103
- .activation (Activation .TANH ).build ())
104
- .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
105
- .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
106
- .build ();
101
+ .seed (seed )
102
+ .updater (new Adam (5e-3 ))
103
+ .l2 (1e-5 )
104
+ .weightInit (WeightInit .XAVIER )
105
+ .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
106
+ .list ()
107
+ .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
108
+ .activation (Activation .TANH ).build ())
109
+ .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
110
+ .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
111
+ .build ();
107
112
108
113
MultiLayerNetwork net = new MultiLayerNetwork (conf );
109
114
net .init ();
110
115
111
116
//DataSetIterators for training and testing respectively
112
- WordVectors wordVectors = WordVectorSerializer .loadStaticModel (new File (WORD_VECTORS_PATH ));
117
+ WordVectors wordVectors = WordVectorSerializer .loadStaticModel (new File (wordVectorsPath ));
113
118
SentimentExampleIterator train = new SentimentExampleIterator (DATA_PATH , wordVectors , batchSize , truncateReviewsToLength , true );
114
119
SentimentExampleIterator test = new SentimentExampleIterator (DATA_PATH , wordVectors , batchSize , truncateReviewsToLength , false );
115
120
@@ -119,7 +124,7 @@ public static void main(String[] args) throws Exception {
119
124
120
125
//After training: load a single example and generate predictions
121
126
File shortNegativeReviewFile = new File (FilenameUtils .concat (DATA_PATH , "aclImdb/test/neg/12100_1.txt" ));
122
- String shortNegativeReview = FileUtils .readFileToString (shortNegativeReviewFile , (Charset )null );
127
+ String shortNegativeReview = FileUtils .readFileToString (shortNegativeReviewFile , (Charset ) null );
123
128
124
129
INDArray features = test .loadFeaturesFromString (shortNegativeReview , truncateReviewsToLength );
125
130
INDArray networkOutput = net .output (features );
@@ -138,15 +143,15 @@ public static void main(String[] args) throws Exception {
138
143
public static void downloadData () throws Exception {
139
144
//Create directory if required
140
145
File directory = new File (DATA_PATH );
141
- if (!directory .exists ()) directory .mkdir ();
146
+ if (!directory .exists ()) directory .mkdir ();
142
147
143
148
//Download file:
144
149
String archizePath = DATA_PATH + "aclImdb_v1.tar.gz" ;
145
150
File archiveFile = new File (archizePath );
146
151
String extractedPath = DATA_PATH + "aclImdb" ;
147
152
File extractedFile = new File (extractedPath );
148
153
149
- if ( !archiveFile .exists () ) {
154
+ if ( !archiveFile .exists ()) {
150
155
System .out .println ("Starting data download (80MB)..." );
151
156
FileUtils .copyURLToFile (new URL (DATA_URL ), archiveFile );
152
157
System .out .println ("Data (.tar.gz file) downloaded to " + archiveFile .getAbsolutePath ());
@@ -155,14 +160,53 @@ public static void downloadData() throws Exception {
155
160
} else {
156
161
//Assume if archive (.tar.gz) exists, then data has already been extracted
157
162
System .out .println ("Data (.tar.gz file) already exists at " + archiveFile .getAbsolutePath ());
158
- if ( !extractedFile .exists ()){
159
- //Extract tar.gz file to output directory
160
- DataUtilities .extractTarGz (archizePath , DATA_PATH );
163
+ if ( !extractedFile .exists ()) {
164
+ //Extract tar.gz file to output directory
165
+ DataUtilities .extractTarGz (archizePath , DATA_PATH );
161
166
} else {
162
- System .out .println ("Data (extracted) already exists at " + extractedFile .getAbsolutePath ());
167
+ System .out .println ("Data (extracted) already exists at " + extractedFile .getAbsolutePath ());
163
168
}
164
169
}
165
170
}
166
171
167
-
172
+ public static void checkDownloadW2VECModel () throws IOException {
173
+ String defaultwordVectorsPath = FilenameUtils .concat (System .getProperty ("user.home" ), "dl4j-examples-data/w2vec300" );
174
+ wordVectorsPath = new File (defaultwordVectorsPath , "GoogleNews-vectors-negative300.bin.gz" ).getAbsolutePath ();
175
+ if (new File (wordVectorsPath ).exists ()) {
176
+ System .out .println ("\n \t GoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath );
177
+ System .out .println ("\t Checking md5 of existing file.." );
178
+ if (Downloader .checkMD5OfFile ("1c892c4707a8a1a508b01a01735c0339" , new File (wordVectorsPath ))) {
179
+ System .out .println ("\t Existing file hash matches." );
180
+ return ;
181
+ } else {
182
+ System .out .println ("\t Existing file hash doesn't match. Retrying download..." );
183
+ }
184
+ } else {
185
+ System .out .println ("\n \t No previous download of GoogleNews-vectors-negative300.bin.gz found at path: " + defaultwordVectorsPath );
186
+ }
187
+ System .out .println ("\t WARNING: GoogleNews-vectors-negative300.bin.gz is a 1.5GB file." );
188
+ System .out .println ("\t Press \" ENTER\" to start a download of GoogleNews-vectors-negative300.bin.gz to " + defaultwordVectorsPath );
189
+ Scanner scanner = new Scanner (System .in );
190
+ scanner .nextLine ();
191
+ System .out .println ("Starting model download (1.5GB!)..." );
192
+ String downloadScript = new ClassPathResource ("w2vecdownload/word2vec-download300model.sh" ).getFile ().getAbsolutePath ();
193
+ ProcessBuilder processBuilder = new ProcessBuilder (downloadScript , defaultwordVectorsPath );
194
+ try {
195
+ processBuilder .inheritIO ();
196
+ Process process = processBuilder .start ();
197
+ int exitVal = process .waitFor ();
198
+ if (exitVal == 0 ) {
199
+ System .out .println ("Successfully downloaded word2vec model!" );
200
+ } else {
201
+ System .out .println ("Download failed. Please download model manually and set the \" wordVectorsPath\" in the code with the path to it." );
202
+ System .exit (0 );
203
+ }
204
+ } catch (IOException e ) {
205
+ e .printStackTrace ();
206
+ } catch (InterruptedException e ) {
207
+ e .printStackTrace ();
208
+ }
209
+ }
168
210
}
211
+
212
+
0 commit comments