Skip to content

Commit a66ad14

Browse files
committed
processnews sample.
Signed-off-by: Robert Altena <[email protected]>
1 parent fe1560e commit a66ad14

File tree

4 files changed

+82
-111
lines changed

4 files changed

+82
-111
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processnews/NewsIterator.java

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -14,21 +14,6 @@
1414
* SPDX-License-Identifier: Apache-2.0
1515
******************************************************************************/
1616

17-
/**-
18-
* This is a DataSetIterator that is specialized for the News headlines dataset used in the TrainNews example
19-
* It takes either the train or test set data from this data set, plus a WordVectors object generated by
20-
* PrepareWordVector.java program and generates training data sets.<br>
21-
* Inputs/features: variable-length time series, where each word (with unknown words removed) is represented by
22-
* its Word2Vec vector representation.<br>
23-
* Labels/target: a single class (representing category, i.e. 0,1,2 etc. depending on content of categories.txt
24-
* file mentioned in TrainNews.java program.
25-
* <p>
26-
* Note :
27-
* - This program is a modification of original example named SentimentExampleIterator.java
28-
* - more details is given with each function's comments in the code
29-
* <p>
30-
* <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
31-
*/
3217

3318

3419
package org.deeplearning4j.examples.recurrent.processnews;
@@ -48,13 +33,31 @@
4833
import java.io.File;
4934
import java.io.FileReader;
5035
import java.io.IOException;
36+
import java.nio.charset.Charset;
5137
import java.util.ArrayList;
5238
import java.util.List;
5339
import java.util.NoSuchElementException;
5440

5541
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
5642
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
5743

44+
/**-
45+
* This is a DataSetIterator that is specialized for the News headlines dataset used in the TrainNews example
46+
* It takes either the train or test set data from this data set, plus a WordVectors object generated by
47+
* PrepareWordVector.java program and generates training data sets.<br>
48+
* Inputs/features: variable-length time series, where each word (with unknown words removed) is represented by
49+
* its Word2Vec vector representation.<br>
50+
* Labels/target: a single class (representing category, i.e. 0,1,2 etc. depending on content of categories.txt
51+
* file mentioned in TrainNews.java program.
52+
* <p>
53+
* Note :
54+
* - This program is a modification of original example named SentimentExampleIterator.java
55+
* - more details is given with each function's comments in the code
56+
* <p>
57+
* <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
58+
*/
59+
@SuppressWarnings({"unused", "WeakerAccess"})
60+
//people may want to use this class outside this example. keeping unused methods and public access.
5861
public class NewsIterator implements DataSetIterator {
5962
private final WordVectors wordVectors;
6063
private final int batchSize;
@@ -95,8 +98,8 @@ private NewsIterator(String dataDirectory,
9598
this.tokenizerFactory = tokenizerFactory;
9699
this.populateData(train);
97100
this.labels = new ArrayList<>();
98-
for (int i = 0; i < this.categoryData.size(); i++) {
99-
this.labels.add(this.categoryData.get(i).getKey().split(",")[1]);
101+
for (Pair<String, List<String>> categoryDatum : this.categoryData) {
102+
this.labels.add(categoryDatum.getKey().split(",")[1]);
100103
}
101104
}
102105

@@ -108,14 +111,10 @@ public static Builder Builder() {
108111
@Override
109112
public DataSet next(int num) {
110113
if (cursor >= this.totalNews) throw new NoSuchElementException();
111-
try {
112-
return nextDataSet(num);
113-
} catch (IOException e) {
114-
throw new RuntimeException(e);
115-
}
114+
return nextDataSet(num);
116115
}
117116

118-
private DataSet nextDataSet(int num) throws IOException {
117+
private DataSet nextDataSet(int num) {
119118
// Loads news into news list from categoryData List along with category of each news
120119
List<String> news = new ArrayList<>(num);
121120
int[] category = new int[num];
@@ -181,8 +180,7 @@ private DataSet nextDataSet(int num) throws IOException {
181180
labelsMask.putScalar(new int[]{i, lastIdx - 1}, 1.0);
182181
}
183182

184-
DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);
185-
return ds;
183+
return new DataSet(features, labels, featuresMask, labelsMask);
186184
}
187185

188186
/**
@@ -194,7 +192,7 @@ private DataSet nextDataSet(int num) throws IOException {
194192
* @throws IOException If file cannot be read
195193
*/
196194
public INDArray loadFeaturesFromFile(File file, int maxLength) throws IOException {
197-
String news = FileUtils.readFileToString(file);
195+
String news = FileUtils.readFileToString(file, (Charset)null);
198196
return loadFeaturesFromString(news, maxLength);
199197
}
200198

@@ -233,14 +231,14 @@ private void populateData(boolean train) {
233231
File categories = new File(this.dataDirectory + File.separator + "categories.txt");
234232

235233
try (BufferedReader brCategories = new BufferedReader(new FileReader(categories))) {
236-
String temp = "";
234+
String temp;
237235
while ((temp = brCategories.readLine()) != null) {
238-
String curFileName = train == true ?
236+
String curFileName = train ?
239237
this.dataDirectory + File.separator + "train" + File.separator + temp.split(",")[0] + ".txt" :
240238
this.dataDirectory + File.separator + "test" + File.separator + temp.split(",")[0] + ".txt";
241239
File currFile = new File(curFileName);
242240
BufferedReader currBR = new BufferedReader((new FileReader(currFile)));
243-
String tempCurrLine = "";
241+
String tempCurrLine;
244242
List<String> tempList = new ArrayList<>();
245243
while ((tempCurrLine = currBR.readLine()) != null) {
246244
tempList.add(tempCurrLine);
@@ -250,7 +248,6 @@ private void populateData(boolean train) {
250248
Pair<String, List<String>> tempPair = Pair.of(temp, tempList);
251249
this.categoryData.add(tempPair);
252250
}
253-
brCategories.close();
254251
} catch (Exception e) {
255252
System.out.println("Exception in reading file :" + e.getMessage());
256253
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processnews/PrepareWordVector.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -79,6 +79,7 @@ public static void main(String[] args) throws Exception {
7979
log.info("Writing word vectors to text file....");
8080

8181
// Write word vectors to file
82+
//noinspection unchecked
8283
WordVectorSerializer.writeWordVectors(vec.lookupTable(), new File(dataLocalPath, "NewsWordVector.txt").getAbsolutePath());
8384
}
8485
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processnews/TestNews.java

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -14,14 +14,6 @@
1414
* SPDX-License-Identifier: Apache-2.0
1515
******************************************************************************/
1616

17-
/**-
18-
* This is a test program that uses word vector and trained network generated by PrepareWordVector.java and TrainNews.java
19-
* - Type or copy/paste news headline from news (indian news channel is preferred) and click on Check button
20-
* and see the predicted category right to the Check button
21-
* <p>
22-
* <b></b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
23-
*/
24-
2517
package org.deeplearning4j.examples.recurrent.processnews;
2618

2719
import org.deeplearning4j.examples.download.DownloaderUtility;
@@ -37,6 +29,7 @@
3729
import org.nd4j.linalg.indexing.INDArrayIndex;
3830
import org.nd4j.linalg.indexing.NDArrayIndex;
3931

32+
import javax.swing.*;
4033
import java.io.BufferedReader;
4134
import java.io.File;
4235
import java.io.FileReader;
@@ -45,22 +38,23 @@
4538
import java.util.logging.Level;
4639
import java.util.logging.Logger;
4740

41+
/**-
42+
* This is a test program that uses word vector and trained network generated by PrepareWordVector.java and TrainNews.java
43+
* - Type or copy/paste news headline from news (indian news channel is preferred) and click on Check button
44+
* and see the predicted category right to the Check button
45+
* <p>
46+
* <b></b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
47+
*/
4848
public class TestNews extends javax.swing.JFrame {
4949
private static WordVectors wordVectors;
5050
private static TokenizerFactory tokenizerFactory;
51-
private static int maxLength = 8;
5251

53-
// Variables declaration - do not modify
54-
private javax.swing.JButton jButton1;
55-
private javax.swing.JLabel jLabel1;
56-
private javax.swing.JLabel jLabel2;
5752
private javax.swing.JLabel jLabel3;
58-
private javax.swing.JScrollPane jScrollPane1;
5953
private javax.swing.JTextArea jTextArea1;
6054
private static MultiLayerNetwork net;
6155
private static String dataLocalPath;
6256

63-
public TestNews() throws Exception {
57+
private TestNews() throws Exception {
6458
initComponents();
6559
dataLocalPath = DownloaderUtility.NEWSDATA.Download();
6660
}
@@ -70,16 +64,16 @@ public TestNews() throws Exception {
7064
* WARNING: Do NOT modify this code. The content of this method is always
7165
* regenerated by the Form Editor.
7266
*/
73-
@SuppressWarnings("unchecked")
7467
// <editor-fold defaultstate="collapsed" desc="Generated Code">
7568
private void initComponents() {
7669

7770
this.setTitle("Predict News Category - KITS");
78-
jLabel1 = new javax.swing.JLabel();
79-
jScrollPane1 = new javax.swing.JScrollPane();
71+
javax.swing.JLabel jLabel1 = new javax.swing.JLabel();
72+
javax.swing.JScrollPane jScrollPane1 = new javax.swing.JScrollPane();
8073
jTextArea1 = new javax.swing.JTextArea();
81-
jButton1 = new javax.swing.JButton();
82-
jLabel2 = new javax.swing.JLabel();
74+
// Variables declaration - do not modify
75+
javax.swing.JButton jButton1 = new javax.swing.JButton();
76+
javax.swing.JLabel jLabel2 = new javax.swing.JLabel();
8377
jLabel3 = new javax.swing.JLabel();
8478

8579
setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);
@@ -91,11 +85,7 @@ private void initComponents() {
9185
jScrollPane1.setViewportView(jTextArea1);
9286

9387
jButton1.setText("Check");
94-
jButton1.addActionListener(new java.awt.event.ActionListener() {
95-
public void actionPerformed(java.awt.event.ActionEvent evt) {
96-
jButton1ActionPerformed(evt);
97-
}
98-
});
88+
jButton1.addActionListener(this::jButton1ActionPerformed);
9989

10090
jLabel2.setText("Category");
10191

@@ -142,10 +132,6 @@ private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {
142132
INDArray fet = testNews.getFeatures();
143133
INDArray predicted = net.output(fet, false);
144134
long[] arrsiz = predicted.shape();
145-
double crimeTotal = 0;
146-
double politicsTotal = 0;
147-
double bollywoodTotal = 0;
148-
double developmentTotal = 0;
149135

150136
File categories = new File(dataLocalPath, "LabelledNews/categories.txt");
151137

@@ -159,7 +145,7 @@ private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {
159145
}
160146

161147
try (BufferedReader brCategories = new BufferedReader(new FileReader(categories))) {
162-
String temp = "";
148+
String temp;
163149
List<String> labels = new ArrayList<>();
164150
while ((temp = brCategories.readLine()) != null) {
165151
labels.add(temp);
@@ -171,7 +157,7 @@ private void jButton1ActionPerformed(java.awt.event.ActionEvent evt) {
171157
}
172158
}
173159

174-
public static void main(String args[]) throws Exception{
160+
public static void main(String[] args) throws Exception{
175161

176162
try {
177163
for (javax.swing.UIManager.LookAndFeelInfo info : javax.swing.UIManager.getInstalledLookAndFeels()) {
@@ -180,32 +166,23 @@ public static void main(String args[]) throws Exception{
180166
break;
181167
}
182168
}
183-
} catch (ClassNotFoundException ex) {
184-
Logger.getLogger(TestNews.class.getName()).log(Level.SEVERE, null, ex);
185-
} catch (InstantiationException ex) {
186-
Logger.getLogger(TestNews.class.getName()).log(Level.SEVERE, null, ex);
187-
} catch (IllegalAccessException ex) {
188-
Logger.getLogger(TestNews.class.getName()).log(Level.SEVERE, null, ex);
189-
} catch (javax.swing.UnsupportedLookAndFeelException ex) {
169+
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException | UnsupportedLookAndFeelException ex) {
190170
Logger.getLogger(TestNews.class.getName()).log(Level.SEVERE, null, ex);
191171
}
192172
TestNews test = new TestNews();
193173
test.setVisible(true);
194174

195-
try {
196-
tokenizerFactory = new DefaultTokenizerFactory();
197-
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
198-
net = MultiLayerNetwork.load(new File(dataLocalPath,"NewsModel.net"), true);
199-
wordVectors = WordVectorSerializer.readWord2VecModel(new File(dataLocalPath,"NewsWordVector.txt"));
200-
} catch (Exception e) {
201-
202-
}
175+
tokenizerFactory = new DefaultTokenizerFactory();
176+
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
177+
net = MultiLayerNetwork.load(new File(dataLocalPath,"NewsModel.net"), true);
178+
wordVectors = WordVectorSerializer.readWord2VecModel(new File(dataLocalPath,"NewsWordVector.txt"));
203179
}
204180

181+
// One news story gets transformed into a dataset with one element.
182+
@SuppressWarnings("DuplicatedCode")
205183
private static DataSet prepareTestData(String i_news) {
206184
List<String> news = new ArrayList<>(1);
207185
int[] category = new int[1];
208-
int currCategory = 0;
209186
news.add(i_news);
210187

211188
List<List<String>> allTokens = new ArrayList<>(news.size());
@@ -246,7 +223,6 @@ private static DataSet prepareTestData(String i_news) {
246223
labelsMask.putScalar(new int[]{i, lastIdx - 1}, 1.0);
247224
}
248225

249-
DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);
250-
return ds;
226+
return new DataSet(features, labels, featuresMask, labelsMask);
251227
}
252228
}

dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/processnews/TrainNews.java

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*******************************************************************************
1+
/* *****************************************************************************
22
* Copyright (c) 2015-2019 Skymind, Inc.
33
*
44
* This program and the accompanying materials are made available under the
@@ -14,6 +14,31 @@
1414
* SPDX-License-Identifier: Apache-2.0
1515
******************************************************************************/
1616

17+
package org.deeplearning4j.examples.recurrent.processnews;
18+
19+
import org.deeplearning4j.examples.download.DownloaderUtility;
20+
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
21+
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
22+
import org.deeplearning4j.nn.conf.GradientNormalization;
23+
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
24+
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
25+
import org.deeplearning4j.nn.conf.layers.LSTM;
26+
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
27+
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
28+
import org.deeplearning4j.nn.weights.WeightInit;
29+
import org.deeplearning4j.optimize.api.InvocationType;
30+
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
31+
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
32+
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
33+
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
34+
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
35+
import org.nd4j.evaluation.classification.Evaluation;
36+
import org.nd4j.linalg.activations.Activation;
37+
import org.nd4j.linalg.learning.config.RmsProp;
38+
import org.nd4j.linalg.lossfunctions.LossFunctions;
39+
40+
import java.io.File;
41+
1742
/**-
1843
* This program trains a RNN to predict category of a news headlines. It uses word vector generated from PrepareWordVector.java.
1944
* - Labeled News are stored in \dl4j-examples\src\main\resources\NewsData\LabelledNews folder in train and test folders.
@@ -48,37 +73,9 @@
4873
* <p>
4974
* <b>KIT Solutions Pvt. Ltd. (www.kitsol.com)</b>
5075
*/
51-
52-
package org.deeplearning4j.examples.recurrent.processnews;
53-
54-
import org.deeplearning4j.eval.Evaluation;
55-
import org.deeplearning4j.examples.download.DownloaderUtility;
56-
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
57-
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
58-
import org.deeplearning4j.nn.conf.GradientNormalization;
59-
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
60-
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
61-
import org.deeplearning4j.nn.conf.layers.LSTM;
62-
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
63-
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
64-
import org.deeplearning4j.nn.weights.WeightInit;
65-
import org.deeplearning4j.optimize.api.InvocationType;
66-
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
67-
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
68-
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
69-
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
70-
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
71-
import org.nd4j.linalg.activations.Activation;
72-
import org.nd4j.linalg.learning.config.RmsProp;
73-
import org.nd4j.linalg.lossfunctions.LossFunctions;
74-
75-
import java.io.File;
76-
77-
7876
public class TrainNews {
7977
public static String DATA_PATH = "";
8078
public static WordVectors wordVectors;
81-
private static TokenizerFactory tokenizerFactory;
8279

8380
public static void main(String[] args) throws Exception {
8481
String dataLocalPath = DownloaderUtility.NEWSDATA.Download();

0 commit comments

Comments
 (0)