Skip to content

Commit 12e0d01

Browse files
committed
OPENNLP-855: Clean up Sentiment Analysis implementation and add tests
- Remove broken sequence labeling code (find/predict2) from SentimentME and SentimentDetector; sentiment is a classification task, not sequence labeling - Remove getSentimentModel() from SentimentModel (wrapped MaxentModel in unused BeamSearch) - Add toString/equals/hashCode to SentimentSample, remove unused id field - Fix SentimentSampleTypeFilter to actually filter by sentiment type - Fix SentimentDetailedFMeasureListener.asSpanArray() returning null - Remove dead detailedFListener code in CLI tools - Add 44 unit tests covering all runtime sentiment classes
1 parent 5e57dee commit 12e0d01

File tree

12 files changed

+116
-238
lines changed

12 files changed

+116
-238
lines changed

opennlp-api/src/main/java/opennlp/tools/sentiment/SentimentDetector.java

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package opennlp.tools.sentiment;
1919

20-
import opennlp.tools.util.Span;
21-
2220
public interface SentimentDetector {
2321

2422
/**
@@ -36,30 +34,4 @@ public interface SentimentDetector {
3634
* @return The predicted sentiment.
3735
*/
3836
String predict(String[] tokens);
39-
40-
/**
41-
* Generates sentiment tags for the given sequence, typically a sentence,
42-
* returning token spans for any identified sentiments.
43-
*
44-
* @param tokens
45-
* an array of the tokens or words of the sequence, typically a
46-
* sentence
47-
* @return an array of spans for each of the names identified.
48-
*/
49-
Span[] find(String[] tokens);
50-
51-
/**
52-
* Generates sentiment tags for the given sequence, typically a sentence,
53-
* returning token spans for any identified sentiments.
54-
*
55-
* @param tokens
56-
* an array of the tokens or words of the sequence, typically a
57-
* sentence.
58-
* @param additionalContext
59-
* features which are based on context outside of the sentence but
60-
* which should also be used.
61-
*
62-
* @return an array of spans for each of the names identified.
63-
*/
64-
Span[] find(String[] tokens, String[][] additionalContext);
6537
}

opennlp-api/src/main/java/opennlp/tools/sentiment/SentimentSample.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.io.Serial;
2121
import java.util.List;
22+
import java.util.Objects;
2223

2324
import opennlp.tools.commons.Sample;
2425

@@ -33,7 +34,6 @@ public class SentimentSample implements Sample {
3334
private final String sentiment;
3435
private final List<String> sentence;
3536
private final boolean isClearAdaptiveData;
36-
private final String id = null;
3737

3838
/**
3939
* Instantiates a {@link SentimentSample} object.
@@ -75,15 +75,34 @@ public String[] getSentence() {
7575
return sentence.toArray(new String[0]);
7676
}
7777

78-
public String getId() {
79-
return id;
80-
}
81-
8278
/**
8379
* @return Returns the value of isClearAdaptiveData, {@code true} or {@code false}.
8480
*/
8581
public boolean isClearAdaptiveDataSet() {
8682
return isClearAdaptiveData;
8783
}
8884

85+
@Override
86+
public String toString() {
87+
return sentiment + " " + String.join(" ", sentence);
88+
}
89+
90+
@Override
91+
public boolean equals(Object obj) {
92+
if (this == obj) {
93+
return true;
94+
}
95+
if (obj == null || getClass() != obj.getClass()) {
96+
return false;
97+
}
98+
SentimentSample that = (SentimentSample) obj;
99+
return Objects.equals(sentiment, that.sentiment)
100+
&& Objects.equals(sentence, that.sentence);
101+
}
102+
103+
@Override
104+
public int hashCode() {
105+
return Objects.hash(sentiment, sentence);
106+
}
107+
89108
}

opennlp-core/opennlp-cli/src/main/java/opennlp/tools/cmdline/sentiment/SentimentCrossValidatorTool.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ public void run(String format, String[] args) {
8383
if (params.getMisclassified()) {
8484
listeners.add(new SentimentEvaluationErrorListener());
8585
}
86-
SentimentDetailedFMeasureListener detailedFListener = null;
8786
SentimentFactory sentimentFactory = new SentimentFactory();
8887

8988
SentimentCrossValidator validator;
@@ -107,12 +106,7 @@ public void run(String format, String[] args) {
107106
System.out.println("done");
108107

109108
System.out.println();
110-
111-
if (detailedFListener == null) {
112-
System.out.println(validator.getFMeasure());
113-
} else {
114-
System.out.println(detailedFListener.toString());
115-
}
109+
System.out.println(validator.getFMeasure());
116110
}
117111

118112
}

opennlp-core/opennlp-cli/src/main/java/opennlp/tools/cmdline/sentiment/SentimentDetailedFMeasureListener.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ public class SentimentDetailedFMeasureListener
3838
*/
3939
@Override
4040
protected Span[] asSpanArray(SentimentSample sample) {
41-
return null;
41+
return new Span[] { new Span(0, 0, sample.getSentiment()) };
4242
}
4343
}

opennlp-core/opennlp-cli/src/main/java/opennlp/tools/cmdline/sentiment/SentimentEvaluatorTool.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ public void run(String format, String[] args) {
8989
if (params.getMisclassified()) {
9090
listeners.add(new SentimentEvaluationErrorListener());
9191
}
92-
SentimentDetailedFMeasureListener detailedFListener = null;
93-
9492
if (params.getNameTypes() != null) {
9593
String[] nameTypes = params.getNameTypes().split(",");
9694
sampleStream = new SentimentSampleTypeFilter(nameTypes, sampleStream);
@@ -142,12 +140,7 @@ public void close() throws IOException {
142140
monitor.stopAndPrintFinalResult();
143141

144142
System.out.println();
145-
146-
if (detailedFListener == null) {
147-
System.out.println(evaluator.getFMeasure());
148-
} else {
149-
System.out.println(detailedFListener.toString());
150-
}
143+
System.out.println(evaluator.getFMeasure());
151144
}
152145

153146
}

opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/sentiment/SentimentCrossValidator.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,14 @@ public void evaluate(ObjectStream<SentimentSample> samples, int nFolds)
204204
CrossValidationPartitioner<DocumentSample> partitioner = new CrossValidationPartitioner<>(
205205
new SentimentToDocumentSampleStream(samples), nFolds);
206206

207-
SentimentModel model = null;
208-
209207
while (partitioner.hasNext()) {
210208

211209
CrossValidationPartitioner.TrainingSampleStream<DocumentSample> trainingSampleStream = partitioner
212210
.next();
213211

214-
if (factory != null) {
215-
model = SentimentME.train(languageCode,
216-
new DocumentToSentimentSampleStream(trainingSampleStream), params,
217-
factory);
218-
}
212+
SentimentModel model = SentimentME.train(languageCode,
213+
new DocumentToSentimentSampleStream(trainingSampleStream), params,
214+
factory);
219215

220216
// do testing
221217
SentimentEvaluator evaluator = new SentimentEvaluator(

opennlp-core/opennlp-runtime/src/main/java/opennlp/tools/sentiment/SentimentME.java

Lines changed: 11 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,14 @@
1919

2020
import java.io.IOException;
2121
import java.util.HashMap;
22-
import java.util.List;
2322
import java.util.Map;
2423

2524
import opennlp.tools.ml.EventTrainer;
2625
import opennlp.tools.ml.TrainerFactory;
2726
import opennlp.tools.ml.model.Event;
2827
import opennlp.tools.ml.model.MaxentModel;
29-
import opennlp.tools.ml.model.SequenceClassificationModel;
30-
import opennlp.tools.namefind.BioCodec;
3128
import opennlp.tools.util.ObjectStream;
32-
import opennlp.tools.util.Sequence;
33-
import opennlp.tools.util.SequenceCodec;
34-
import opennlp.tools.util.SequenceValidator;
35-
import opennlp.tools.util.Span;
3629
import opennlp.tools.util.TrainingParameters;
37-
import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator;
38-
import opennlp.tools.util.featuregen.AdditionalContextFeatureGenerator;
3930

4031
/**
4132
* A {@link SentimentDetector} implementation for creating and using
@@ -45,22 +36,9 @@
4536
*/
4637
public class SentimentME implements SentimentDetector {
4738

48-
public static final int DEFAULT_BEAM_SIZE = 3;
49-
50-
private static final String[][] EMPTY = new String[0][0];
51-
52-
protected SentimentContextGenerator contextGenerator;
53-
54-
private final AdditionalContextFeatureGenerator additionalContextFeatureGenerator =
55-
new AdditionalContextFeatureGenerator();
56-
57-
private Sequence bestSequence;
58-
private SequenceValidator<String> sequenceValidator;
59-
private final SequenceClassificationModel model;
39+
private final SentimentContextGenerator contextGenerator;
6040
private final SentimentFactory factory;
6141
private final MaxentModel maxentModel;
62-
private final SequenceCodec<String> seqCodec = new BioCodec();
63-
private AdaptiveFeatureGenerator[] featureGenerators;
6442

6543
/**
6644
* Instantiates a {@link SentimentME} with the specified model.
@@ -73,7 +51,6 @@ public SentimentME(SentimentModel sentModel) {
7351
if (sentModel == null) {
7452
throw new IllegalArgumentException("SentimentModel must not be null!");
7553
}
76-
this.model = sentModel.getSentimentModel();
7754
maxentModel = sentModel.getMaxentModel();
7855
factory = sentModel.getFactory();
7956
contextGenerator = factory.createContextGenerator();
@@ -91,22 +68,22 @@ public SentimentME(SentimentModel sentModel) {
9168
* @param factory
9269
* a Sentiment Analysis factory
9370
* @return A valid {@link SentimentModel}.
71+
* @throws IOException Thrown if IO errors occurred during training.
9472
*/
9573
public static SentimentModel train(String languageCode, ObjectStream<SentimentSample> samples,
9674
TrainingParameters trainParams, SentimentFactory factory)
9775
throws IOException {
9876

9977
Map<String, String> entries = new HashMap<>();
100-
MaxentModel sentimentModel;
10178

10279
ObjectStream<Event> eventStream = new SentimentEventStream(samples, factory.createContextGenerator());
10380

10481
EventTrainer<TrainingParameters> trainer = TrainerFactory.getEventTrainer(trainParams, entries);
105-
sentimentModel = trainer.train(eventStream);
82+
MaxentModel sentimentModel = trainer.train(eventStream);
10683

107-
return new SentimentModel(languageCode, sentimentModel, new HashMap<>(), factory);
84+
return new SentimentModel(languageCode, sentimentModel, entries, factory);
10885
}
109-
86+
11087
@Override
11188
public String predict(String sentence) {
11289
String[] tokens = factory.getTokenizer().tokenize(sentence);
@@ -120,126 +97,22 @@ public String predict(String[] tokens) {
12097
}
12198

12299
/**
123-
* Returns the best chosen sentiment for the text predicted on
100+
* Returns the best chosen sentiment for the given probability distribution.
124101
*
125-
* @param outcome
126-
* the outcome
127-
* @return the best sentiment
102+
* @param outcome the probability distribution over outcomes.
103+
* @return the best sentiment label.
128104
*/
129105
public String getBestSentiment(double[] outcome) {
130106
return maxentModel.getBestOutcome(outcome);
131107
}
132108

133109
/**
134-
* Returns the analysis probabilities
110+
* Returns the probability distribution over sentiment labels for the given tokens.
135111
*
136-
* @param text
137-
* the text to categorize
112+
* @param text the tokens to classify.
113+
* @return the probability distribution over sentiment labels.
138114
*/
139115
public double[] probabilities(String[] text) {
140116
return maxentModel.eval(contextGenerator.getContext(text));
141117
}
142-
143-
/**
144-
* Returns an array of probabilities for each of the specified spans which is
145-
* the arithmetic mean of the probabilities for each of the outcomes which
146-
* make up the span.
147-
*
148-
* @param spans
149-
* The spans of the sentiments for which probabilities are desired.
150-
* @return an array of probabilities for each of the specified spans.
151-
*/
152-
public double[] probs(Span[] spans) {
153-
154-
double[] sprobs = new double[spans.length];
155-
double[] probs = bestSequence.getProbs();
156-
157-
for (int si = 0; si < spans.length; si++) {
158-
159-
double p = 0;
160-
161-
for (int oi = spans[si].getStart(); oi < spans[si].getEnd(); oi++) {
162-
p += probs[oi];
163-
}
164-
165-
p /= spans[si].length();
166-
167-
sprobs[si] = p;
168-
}
169-
170-
return sprobs;
171-
}
172-
173-
/**
174-
* Sets the probs for the spans
175-
*
176-
* @param spans
177-
* the spans to be analysed
178-
* @return the span of probs
179-
*/
180-
private Span[] setProbs(Span[] spans) {
181-
double[] probs = probs(spans);
182-
if (probs != null) {
183-
184-
for (int i = 0; i < probs.length; i++) {
185-
double prob = probs[i];
186-
spans[i] = new Span(spans[i], prob);
187-
}
188-
}
189-
return spans;
190-
}
191-
192-
@Override
193-
public Span[] find(String[] tokens) {
194-
return find(tokens, EMPTY);
195-
}
196-
197-
@Override
198-
public Span[] find(String[] tokens, String[][] additionalContext) {
199-
200-
additionalContextFeatureGenerator.setCurrentContext(additionalContext);
201-
202-
bestSequence = model.bestSequence(tokens, additionalContext,
203-
contextGenerator, sequenceValidator);
204-
205-
List<String> c = bestSequence.getOutcomes();
206-
207-
contextGenerator.updateAdaptiveData(tokens, c.toArray(new String[0]));
208-
Span[] spans = seqCodec.decode(c);
209-
spans = setProbs(spans);
210-
return spans;
211-
}
212-
213-
/**
214-
* Makes a sentiment prediction by calling the helper method
215-
*
216-
* @param tokens
217-
* the text to be analysed for its sentiment
218-
* @return the prediction made by the helper method
219-
*/
220-
public Span[] predict2(String[] tokens) {
221-
return predict2(tokens, EMPTY);
222-
}
223-
224-
/**
225-
* Makes a sentiment prediction
226-
*
227-
* @param tokens
228-
* the text to be analysed for its sentiment
229-
* @param additionalContext
230-
* any required additional context
231-
* @return the predictions
232-
*/
233-
public Span[] predict2(String[] tokens, String[][] additionalContext) {
234-
235-
additionalContextFeatureGenerator.setCurrentContext(additionalContext);
236-
237-
bestSequence = model.bestSequence(tokens, additionalContext,
238-
contextGenerator, sequenceValidator);
239-
240-
List<String> c = bestSequence.getOutcomes();
241-
242-
return seqCodec.decode(c);
243-
}
244-
245118
}

0 commit comments

Comments
 (0)