Skip to content

Commit cba6062

Browse files
committed
java only download from our hosted copy
Signed-off-by: eraly <[email protected]>
1 parent a55508f commit cba6062

File tree

2 files changed

+15
-44
lines changed

2 files changed

+15
-44
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.nd4j.linalg.api.ndarray.INDArray;
3636
import org.nd4j.linalg.factory.Nd4j;
3737
import org.nd4j.linalg.indexing.NDArrayIndex;
38-
import org.nd4j.linalg.io.ClassPathResource;
3938
import org.nd4j.linalg.learning.config.Adam;
4039
import org.nd4j.linalg.lossfunctions.LossFunctions;
4140
import org.nd4j.resources.Downloader;
@@ -98,17 +97,17 @@ public static void main(String[] args) throws Exception {
9897

9998
//Set up network configuration
10099
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
101-
.seed(seed)
102-
.updater(new Adam(5e-3))
103-
.l2(1e-5)
104-
.weightInit(WeightInit.XAVIER)
105-
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
106-
.list()
107-
.layer(new LSTM.Builder().nIn(vectorSize).nOut(256)
108-
.activation(Activation.TANH).build())
109-
.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
110-
.lossFunction(LossFunctions.LossFunction.MCXENT).nIn(256).nOut(2).build())
111-
.build();
100+
.seed(seed)
101+
.updater(new Adam(5e-3))
102+
.l2(1e-5)
103+
.weightInit(WeightInit.XAVIER)
104+
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
105+
.list()
106+
.layer(new LSTM.Builder().nIn(vectorSize).nOut(256)
107+
.activation(Activation.TANH).build())
108+
.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX)
109+
.lossFunction(LossFunctions.LossFunction.MCXENT).nIn(256).nOut(2).build())
110+
.build();
112111

113112
MultiLayerNetwork net = new MultiLayerNetwork(conf);
114113
net.init();
@@ -171,11 +170,12 @@ public static void downloadData() throws Exception {
171170

172171
public static void checkDownloadW2VECModel() throws IOException {
173172
String defaultwordVectorsPath = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/w2vec300");
173+
String md5w2vec = "1c892c4707a8a1a508b01a01735c0339";
174174
wordVectorsPath = new File(defaultwordVectorsPath, "GoogleNews-vectors-negative300.bin.gz").getAbsolutePath();
175175
if (new File(wordVectorsPath).exists()) {
176176
System.out.println("\n\tGoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath);
177177
System.out.println("\tChecking md5 of existing file..");
178-
if (Downloader.checkMD5OfFile("1c892c4707a8a1a508b01a01735c0339", new File(wordVectorsPath))) {
178+
if (Downloader.checkMD5OfFile(md5w2vec, new File(wordVectorsPath))) {
179179
System.out.println("\tExisting file hash matches.");
180180
return;
181181
} else {
@@ -189,23 +189,8 @@ public static void checkDownloadW2VECModel() throws IOException {
189189
Scanner scanner = new Scanner(System.in);
190190
scanner.nextLine();
191191
System.out.println("Starting model download (1.5GB!)...");
192-
String downloadScript = new ClassPathResource("w2vecdownload/word2vec-download300model.sh").getFile().getAbsolutePath();
193-
ProcessBuilder processBuilder = new ProcessBuilder(downloadScript, defaultwordVectorsPath);
194-
try {
195-
processBuilder.inheritIO();
196-
Process process = processBuilder.start();
197-
int exitVal = process.waitFor();
198-
if (exitVal == 0) {
199-
System.out.println("Successfully downloaded word2vec model!");
200-
} else {
201-
System.out.println("Download failed. Please download model manually and set the \"wordVectorsPath\" in the code with the path to it.");
202-
System.exit(0);
203-
}
204-
} catch (IOException e) {
205-
e.printStackTrace();
206-
} catch (InterruptedException e) {
207-
e.printStackTrace();
208-
}
192+
Downloader.download("Word2Vec", new URL("https://dl4jdata.blob.core.windows.net/resources/wordvectors/GoogleNews-vectors-negative300.bin.gz"), new File(wordVectorsPath), md5w2vec, 5);
193+
System.out.println("Successfully downloaded word2vec model to " + wordVectorsPath);
209194
}
210195
}
211196

dl4j-examples/src/main/resources/w2vecdownload/word2vec-download300model.sh

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)