Skip to content

Commit a55508f

Browse files
committed
modified w2vec sentiment example to download model if doesn't exist
Signed-off-by: eraly <[email protected]>
1 parent 987c689 commit a55508f

File tree

1 file changed

+82
-38
lines changed

1 file changed

+82
-38
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,51 +35,56 @@
3535
import org.nd4j.linalg.api.ndarray.INDArray;
3636
import org.nd4j.linalg.factory.Nd4j;
3737
import org.nd4j.linalg.indexing.NDArrayIndex;
38+
import org.nd4j.linalg.io.ClassPathResource;
3839
import org.nd4j.linalg.learning.config.Adam;
3940
import org.nd4j.linalg.lossfunctions.LossFunctions;
41+
import org.nd4j.resources.Downloader;
4042

4143
import java.io.File;
44+
import java.io.IOException;
4245
import java.net.URL;
4346
import java.nio.charset.Charset;
47+
import java.util.Scanner;
4448

45-
/**Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
49+
/**
50+
* Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
4651
* This is done by combining Word2Vec vectors and a recurrent neural network model. Each word in a review is vectorized
4752
* (using the Word2Vec model) and fed into a recurrent neural network.
4853
* Training data is the "Large Movie Review Dataset" from http://ai.stanford.edu/~amaas/data/sentiment/
4954
* This data set contains 25,000 training reviews + 25,000 testing reviews
50-
*
55+
* <p>
5156
* Process:
57+
* 0. If the path to the word vectors is not set and no previous download is found in the default location, you will be prompted to download the model.
5258
* 1. Automatic on first run of example: Download data (movie reviews) + extract
53-
* 2. Load existing Word2Vec model (for example: Google News word vectors. You will have to download this MANUALLY)
59+
* 2. Load existing Word2Vec model (for example: Google News word vectors.)
5460
* 3. Load each review. Convert words to vectors + reviews to sequences of vectors
5561
* 4. Train network
56-
*
62+
* <p>
5763
* With the current configuration, gives approx. 83% accuracy after 1 epoch. Better performance may be possible with
5864
* additional tuning.
5965
*
60-
* NOTE / INSTRUCTIONS:
61-
* You will have to download the Google News word vector model manually. ~1.5GB
62-
* The Google News vector model available here: https://code.google.com/p/word2vec/
63-
* Download the GoogleNews-vectors-negative300.bin.gz file
64-
* Then: set the WORD_VECTORS_PATH field to point to this location.
65-
*
6666
* @author Alex Black
6767
*/
6868
public class Word2VecSentimentRNN {
6969

70-
/** Data URL for downloading */
70+
/**
71+
* Data URL for downloading
72+
*/
7173
public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
72-
/** Location to save and extract the training/testing data */
74+
/**
75+
* Location to save and extract the training/testing data
76+
*/
7377
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
74-
/** Location (local file system) for the Google News vectors. Set this manually. */
75-
public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
76-
78+
/**
79+
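* Location (local file system) for the Google News vectors. Set this manually, or leave the placeholder value to be prompted to download the model to ~/dl4j-examples-data/w2vec300.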
80+
*/
81+
public static String wordVectorsPath = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
7782

7883
public static void main(String[] args) throws Exception {
79-
if(WORD_VECTORS_PATH.startsWith("/PATH/TO/YOUR/VECTORS/")){
80-
throw new RuntimeException("Please set the WORD_VECTORS_PATH before running this example");
84+
if (wordVectorsPath.startsWith("/PATH/TO/YOUR/VECTORS/")) {
85+
System.out.println("wordVectorsPath has not been set. Checking default location in ~/dl4j-examples-data for download...");
86+
checkDownloadW2VECModel();
8187
}
82-
8388
//Download and extract data
8489
downloadData();
8590

@@ -93,23 +98,23 @@ public static void main(String[] args) throws Exception {
9398

9499
//Set up network configuration
95100
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
96-
.seed(seed)
97-
.updater(new Adam(5e-3))
98-
.l2(1e-5)
99-
.weightInit(WeightInit.XAVIER)
100-
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
101-
.list()
102-
.layer(new LSTM.Builder().nIn(vectorSize).nOut(256)
103-
.activation(Activation.TANH).build())
104-
.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
105-
.lossFunction(LossFunctions.LossFunction.MCXENT).nIn(256).nOut(2).build())
106-
.build();
101+
.seed(seed)
102+
.updater(new Adam(5e-3))
103+
.l2(1e-5)
104+
.weightInit(WeightInit.XAVIER)
105+
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
106+
.list()
107+
.layer(new LSTM.Builder().nIn(vectorSize).nOut(256)
108+
.activation(Activation.TANH).build())
109+
.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
110+
.lossFunction(LossFunctions.LossFunction.MCXENT).nIn(256).nOut(2).build())
111+
.build();
107112

108113
MultiLayerNetwork net = new MultiLayerNetwork(conf);
109114
net.init();
110115

111116
//DataSetIterators for training and testing respectively
112-
WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
117+
WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(wordVectorsPath));
113118
SentimentExampleIterator train = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true);
114119
SentimentExampleIterator test = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, false);
115120

@@ -119,7 +124,7 @@ public static void main(String[] args) throws Exception {
119124

120125
//After training: load a single example and generate predictions
121126
File shortNegativeReviewFile = new File(FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/12100_1.txt"));
122-
String shortNegativeReview = FileUtils.readFileToString(shortNegativeReviewFile, (Charset)null);
127+
String shortNegativeReview = FileUtils.readFileToString(shortNegativeReviewFile, (Charset) null);
123128

124129
INDArray features = test.loadFeaturesFromString(shortNegativeReview, truncateReviewsToLength);
125130
INDArray networkOutput = net.output(features);
@@ -138,15 +143,15 @@ public static void main(String[] args) throws Exception {
138143
public static void downloadData() throws Exception {
139144
//Create directory if required
140145
File directory = new File(DATA_PATH);
141-
if(!directory.exists()) directory.mkdir();
146+
if (!directory.exists()) directory.mkdir();
142147

143148
//Download file:
144149
String archizePath = DATA_PATH + "aclImdb_v1.tar.gz";
145150
File archiveFile = new File(archizePath);
146151
String extractedPath = DATA_PATH + "aclImdb";
147152
File extractedFile = new File(extractedPath);
148153

149-
if( !archiveFile.exists() ){
154+
if (!archiveFile.exists()) {
150155
System.out.println("Starting data download (80MB)...");
151156
FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
152157
System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
@@ -155,14 +160,53 @@ public static void downloadData() throws Exception {
155160
} else {
156161
//Assume if archive (.tar.gz) exists, then data has already been extracted
157162
System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
158-
if( !extractedFile.exists()){
159-
//Extract tar.gz file to output directory
160-
DataUtilities.extractTarGz(archizePath, DATA_PATH);
163+
if (!extractedFile.exists()) {
164+
//Extract tar.gz file to output directory
165+
DataUtilities.extractTarGz(archizePath, DATA_PATH);
161166
} else {
162-
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
167+
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
163168
}
164169
}
165170
}
166171

167-
172+
public static void checkDownloadW2VECModel() throws IOException {
173+
String defaultwordVectorsPath = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/w2vec300");
174+
wordVectorsPath = new File(defaultwordVectorsPath, "GoogleNews-vectors-negative300.bin.gz").getAbsolutePath();
175+
if (new File(wordVectorsPath).exists()) {
176+
System.out.println("\n\tGoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath);
177+
System.out.println("\tChecking md5 of existing file..");
178+
if (Downloader.checkMD5OfFile("1c892c4707a8a1a508b01a01735c0339", new File(wordVectorsPath))) {
179+
System.out.println("\tExisting file hash matches.");
180+
return;
181+
} else {
182+
System.out.println("\tExisting file hash doesn't match. Retrying download...");
183+
}
184+
} else {
185+
System.out.println("\n\tNo previous download of GoogleNews-vectors-negative300.bin.gz found at path: " + defaultwordVectorsPath);
186+
}
187+
System.out.println("\tWARNING: GoogleNews-vectors-negative300.bin.gz is a 1.5GB file.");
188+
System.out.println("\tPress \"ENTER\" to start a download of GoogleNews-vectors-negative300.bin.gz to " + defaultwordVectorsPath);
189+
Scanner scanner = new Scanner(System.in);
190+
scanner.nextLine();
191+
System.out.println("Starting model download (1.5GB!)...");
192+
String downloadScript = new ClassPathResource("w2vecdownload/word2vec-download300model.sh").getFile().getAbsolutePath();
193+
ProcessBuilder processBuilder = new ProcessBuilder(downloadScript, defaultwordVectorsPath);
194+
try {
195+
processBuilder.inheritIO();
196+
Process process = processBuilder.start();
197+
int exitVal = process.waitFor();
198+
if (exitVal == 0) {
199+
System.out.println("Successfully downloaded word2vec model!");
200+
} else {
201+
System.out.println("Download failed. Please download model manually and set the \"wordVectorsPath\" in the code with the path to it.");
202+
System.exit(0);
203+
}
204+
} catch (IOException e) {
205+
e.printStackTrace();
206+
} catch (InterruptedException e) {
207+
e.printStackTrace();
208+
}
209+
}
168210
}
211+
212+

0 commit comments

Comments
 (0)