35
35
import org .nd4j .linalg .api .ndarray .INDArray ;
36
36
import org .nd4j .linalg .factory .Nd4j ;
37
37
import org .nd4j .linalg .indexing .NDArrayIndex ;
38
- import org .nd4j .linalg .io .ClassPathResource ;
39
38
import org .nd4j .linalg .learning .config .Adam ;
40
39
import org .nd4j .linalg .lossfunctions .LossFunctions ;
41
40
import org .nd4j .resources .Downloader ;
@@ -98,17 +97,17 @@ public static void main(String[] args) throws Exception {
98
97
99
98
//Set up network configuration
100
99
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
101
- .seed (seed )
102
- .updater (new Adam (5e-3 ))
103
- .l2 (1e-5 )
104
- .weightInit (WeightInit .XAVIER )
105
- .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
106
- .list ()
107
- .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
108
- .activation (Activation .TANH ).build ())
109
- .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
110
- .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
111
- .build ();
100
+ .seed (seed )
101
+ .updater (new Adam (5e-3 ))
102
+ .l2 (1e-5 )
103
+ .weightInit (WeightInit .XAVIER )
104
+ .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
105
+ .list ()
106
+ .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
107
+ .activation (Activation .TANH ).build ())
108
+ .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
109
+ .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
110
+ .build ();
112
111
113
112
MultiLayerNetwork net = new MultiLayerNetwork (conf );
114
113
net .init ();
@@ -171,11 +170,12 @@ public static void downloadData() throws Exception {
171
170
172
171
public static void checkDownloadW2VECModel () throws IOException {
173
172
String defaultwordVectorsPath = FilenameUtils .concat (System .getProperty ("user.home" ), "dl4j-examples-data/w2vec300" );
173
+ String md5w2vec = "1c892c4707a8a1a508b01a01735c0339" ;
174
174
wordVectorsPath = new File (defaultwordVectorsPath , "GoogleNews-vectors-negative300.bin.gz" ).getAbsolutePath ();
175
175
if (new File (wordVectorsPath ).exists ()) {
176
176
System .out .println ("\n \t GoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath );
177
177
System .out .println ("\t Checking md5 of existing file.." );
178
- if (Downloader .checkMD5OfFile ("1c892c4707a8a1a508b01a01735c0339" , new File (wordVectorsPath ))) {
178
+ if (Downloader .checkMD5OfFile (md5w2vec , new File (wordVectorsPath ))) {
179
179
System .out .println ("\t Existing file hash matches." );
180
180
return ;
181
181
} else {
@@ -189,23 +189,8 @@ public static void checkDownloadW2VECModel() throws IOException {
189
189
Scanner scanner = new Scanner (System .in );
190
190
scanner .nextLine ();
191
191
System .out .println ("Starting model download (1.5GB!)..." );
192
- String downloadScript = new ClassPathResource ("w2vecdownload/word2vec-download300model.sh" ).getFile ().getAbsolutePath ();
193
- ProcessBuilder processBuilder = new ProcessBuilder (downloadScript , defaultwordVectorsPath );
194
- try {
195
- processBuilder .inheritIO ();
196
- Process process = processBuilder .start ();
197
- int exitVal = process .waitFor ();
198
- if (exitVal == 0 ) {
199
- System .out .println ("Successfully downloaded word2vec model!" );
200
- } else {
201
- System .out .println ("Download failed. Please download model manually and set the \" wordVectorsPath\" in the code with the path to it." );
202
- System .exit (0 );
203
- }
204
- } catch (IOException e ) {
205
- e .printStackTrace ();
206
- } catch (InterruptedException e ) {
207
- e .printStackTrace ();
208
- }
192
+ Downloader .download ("Word2Vec" , new URL ("https://dl4jdata.blob.core.windows.net/resources/wordvectors/GoogleNews-vectors-negative300.bin.gz" ), new File (wordVectorsPath ), md5w2vec , 5 );
193
+ System .out .println ("Successfully downloaded word2vec model to " + wordVectorsPath );
209
194
}
210
195
}
211
196
0 commit comments