Skip to content

Commit 4ed29a6

Browse files
authored
Merge pull request #940 from eraly/eraly
Word2vec model automatic download
2 parents 7eb12de + cba6062 commit 4ed29a6

File tree

1 file changed

+56
-27
lines changed

1 file changed

+56
-27
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,49 +37,53 @@
3737
import org.nd4j.linalg.indexing.NDArrayIndex;
3838
import org.nd4j.linalg.learning.config.Adam;
3939
import org.nd4j.linalg.lossfunctions.LossFunctions;
40+
import org.nd4j.resources.Downloader;
4041

4142
import java.io.File;
43+
import java.io.IOException;
4244
import java.net.URL;
4345
import java.nio.charset.Charset;
46+
import java.util.Scanner;
4447

45-
/**Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
48+
/**
49+
* Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
4650
* This is done by combining Word2Vec vectors and a recurrent neural network model. Each word in a review is vectorized
4751
* (using the Word2Vec model) and fed into a recurrent neural network.
4852
* Training data is the "Large Movie Review Dataset" from http://ai.stanford.edu/~amaas/data/sentiment/
4953
* This data set contains 25,000 training reviews + 25,000 testing reviews
50-
*
54+
* <p>
5155
* Process:
56+
* 0. If path to the wordvectors is not set and a download not found previously in the default location you will be prompted if you want to download it.
5257
* 1. Automatic on first run of example: Download data (movie reviews) + extract
53-
* 2. Load existing Word2Vec model (for example: Google News word vectors. You will have to download this MANUALLY)
58+
* 2. Load existing Word2Vec model (for example: Google News word vectors.)
5459
* 3. Load each each review. Convert words to vectors + reviews to sequences of vectors
5560
* 4. Train network
56-
*
61+
* <p>
5762
* With the current configuration, gives approx. 83% accuracy after 1 epoch. Better performance may be possible with
5863
* additional tuning.
5964
*
60-
* NOTE / INSTRUCTIONS:
61-
* You will have to download the Google News word vector model manually. ~1.5GB
62-
* The Google News vector model available here: https://code.google.com/p/word2vec/
63-
* Download the GoogleNews-vectors-negative300.bin.gz file
64-
* Then: set the WORD_VECTORS_PATH field to point to this location.
65-
*
6665
* @author Alex Black
6766
*/
6867
public class Word2VecSentimentRNN {
6968

70-
/** Data URL for downloading */
69+
/**
70+
* Data URL for downloading
71+
*/
7172
public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
72-
/** Location to save and extract the training/testing data */
73+
/**
74+
* Location to save and extract the training/testing data
75+
*/
7376
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
74-
/** Location (local file system) for the Google News vectors. Set this manually. */
75-
public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
76-
77+
/**
78+
* Location (local file system) for the Google News vectors. Set this manually.
79+
*/
80+
public static String wordVectorsPath = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz";
7781

7882
public static void main(String[] args) throws Exception {
79-
if(WORD_VECTORS_PATH.startsWith("/PATH/TO/YOUR/VECTORS/")){
80-
throw new RuntimeException("Please set the WORD_VECTORS_PATH before running this example");
83+
if (wordVectorsPath.startsWith("/PATH/TO/YOUR/VECTORS/")) {
84+
System.out.println("wordVectorsPath has not been set. Checking default location in ~/dl4j-examples-data for download...");
85+
checkDownloadW2VECModel();
8186
}
82-
8387
//Download and extract data
8488
downloadData();
8589

@@ -109,7 +113,7 @@ public static void main(String[] args) throws Exception {
109113
net.init();
110114

111115
//DataSetIterators for training and testing respectively
112-
WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
116+
WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(wordVectorsPath));
113117
SentimentExampleIterator train = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true);
114118
SentimentExampleIterator test = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, false);
115119

@@ -119,7 +123,7 @@ public static void main(String[] args) throws Exception {
119123

120124
//After training: load a single example and generate predictions
121125
File shortNegativeReviewFile = new File(FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/12100_1.txt"));
122-
String shortNegativeReview = FileUtils.readFileToString(shortNegativeReviewFile, (Charset)null);
126+
String shortNegativeReview = FileUtils.readFileToString(shortNegativeReviewFile, (Charset) null);
123127

124128
INDArray features = test.loadFeaturesFromString(shortNegativeReview, truncateReviewsToLength);
125129
INDArray networkOutput = net.output(features);
@@ -138,15 +142,15 @@ public static void main(String[] args) throws Exception {
138142
public static void downloadData() throws Exception {
139143
//Create directory if required
140144
File directory = new File(DATA_PATH);
141-
if(!directory.exists()) directory.mkdir();
145+
if (!directory.exists()) directory.mkdir();
142146

143147
//Download file:
144148
String archizePath = DATA_PATH + "aclImdb_v1.tar.gz";
145149
File archiveFile = new File(archizePath);
146150
String extractedPath = DATA_PATH + "aclImdb";
147151
File extractedFile = new File(extractedPath);
148152

149-
if( !archiveFile.exists() ){
153+
if (!archiveFile.exists()) {
150154
System.out.println("Starting data download (80MB)...");
151155
FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
152156
System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
@@ -155,14 +159,39 @@ public static void downloadData() throws Exception {
155159
} else {
156160
//Assume if archive (.tar.gz) exists, then data has already been extracted
157161
System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
158-
if( !extractedFile.exists()){
159-
//Extract tar.gz file to output directory
160-
DataUtilities.extractTarGz(archizePath, DATA_PATH);
162+
if (!extractedFile.exists()) {
163+
//Extract tar.gz file to output directory
164+
DataUtilities.extractTarGz(archizePath, DATA_PATH);
161165
} else {
162-
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
166+
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
163167
}
164168
}
165169
}
166170

167-
171+
public static void checkDownloadW2VECModel() throws IOException {
172+
String defaultwordVectorsPath = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/w2vec300");
173+
String md5w2vec = "1c892c4707a8a1a508b01a01735c0339";
174+
wordVectorsPath = new File(defaultwordVectorsPath, "GoogleNews-vectors-negative300.bin.gz").getAbsolutePath();
175+
if (new File(wordVectorsPath).exists()) {
176+
System.out.println("\n\tGoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath);
177+
System.out.println("\tChecking md5 of existing file..");
178+
if (Downloader.checkMD5OfFile(md5w2vec, new File(wordVectorsPath))) {
179+
System.out.println("\tExisting file hash matches.");
180+
return;
181+
} else {
182+
System.out.println("\tExisting file hash doesn't match. Retrying download...");
183+
}
184+
} else {
185+
System.out.println("\n\tNo previous download of GoogleNews-vectors-negative300.bin.gz found at path: " + defaultwordVectorsPath);
186+
}
187+
System.out.println("\tWARNING: GoogleNews-vectors-negative300.bin.gz is a 1.5GB file.");
188+
System.out.println("\tPress \"ENTER\" to start a download of GoogleNews-vectors-negative300.bin.gz to " + defaultwordVectorsPath);
189+
Scanner scanner = new Scanner(System.in);
190+
scanner.nextLine();
191+
System.out.println("Starting model download (1.5GB!)...");
192+
Downloader.download("Word2Vec", new URL("https://dl4jdata.blob.core.windows.net/resources/wordvectors/GoogleNews-vectors-negative300.bin.gz"), new File(wordVectorsPath), md5w2vec, 5);
193+
System.out.println("Successfully downloaded word2vec model to " + wordVectorsPath);
194+
}
168195
}
196+
197+

0 commit comments

Comments
 (0)