
Commit 9aebb91

cbt argument added

1 parent c6fbc46

2 files changed (+23, -28 lines)

src/processing/MalletCalculator.java

Lines changed: 13 additions & 22 deletions
@@ -19,9 +19,7 @@ License, or (at your option) any later version.
 */
 package processing;
 
-import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -30,7 +28,6 @@ License, or (at your option) any later version.
 import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
-import java.util.regex.Pattern;
 
 import com.google.common.base.Stopwatch;
 import com.google.common.primitives.Doubles;
@@ -39,23 +36,12 @@ License, or (at your option) any later version.
 import common.DoubleMapComparator;
 import common.Bookmark;
 import common.Utilities;
-import cc.mallet.pipe.Array2FeatureVector;
-import cc.mallet.pipe.CharSequence2TokenSequence;
-import cc.mallet.pipe.CharSequenceArray2TokenSequence;
-import cc.mallet.pipe.CharSequenceLowercase;
-import cc.mallet.pipe.Pipe;
-import cc.mallet.pipe.PrintInputAndTarget;
-import cc.mallet.pipe.SerialPipes;
 import cc.mallet.pipe.StringList2FeatureSequence;
-import cc.mallet.pipe.TokenSequence2FeatureSequence;
-import cc.mallet.pipe.TokenSequenceLowercase;
 import cc.mallet.topics.ParallelTopicModel;
 import cc.mallet.types.Alphabet;
-import cc.mallet.types.FeatureSequence;
 import cc.mallet.types.IDSorter;
 import cc.mallet.types.Instance;
 import cc.mallet.types.InstanceList;
-import cc.mallet.types.TokenSequence;
 import file.PredictionFileWriter;
 import file.BookmarkReader;
 import file.BookmarkSplitter;
@@ -64,7 +50,7 @@ public class MalletCalculator {
 
     private final static int MAX_RECOMMENDATIONS = 10;
     private final static int MAX_TERMS = 100;
-    private final static int NUM_THREADS = 10;
+    //private final static int NUM_THREADS = 10;
     private final static int NUM_ITERATIONS = 2000;
     private final static double ALPHA = 0.01;
     private final static double BETA = 0.01;
@@ -105,10 +91,10 @@ private List<Map<Integer, Double>> getMaxTopicsByDocs(ParallelTopicModel LDA, in
             double[] topicProbs = LDA.getTopicProbabilities(doc);
             //double probSum = 0.0;
             for (int topic = 0; topic < topicProbs.length && topic < maxTopicsPerDoc; topic++) {
-                //if (topicProbs[topic] > 0.01) { // TODO
+                if (topicProbs[topic] > TOPIC_THRESHOLD) { // TODO
                     topicList.put(topic, topicProbs[topic]);
                     //probSum += topicProbs[topic];
-                //}
+                }
             }
             //System.out.println("Topic Sum: " + probSum);
             Map<Integer, Double> sortedTopicList = new TreeMap<Integer, Double>(new DoubleMapComparator(topicList));
@@ -151,7 +137,7 @@ private List<Map<Integer, Double>> getMaxTermsByTopics(ParallelTopicModel LDA, i
         return topicList;
     }
 
-    public void predictValuesProbs() {
+    public void predictValuesProbs(boolean topicCreation) {
         ParallelTopicModel LDA = new ParallelTopicModel(this.numTopics, ALPHA * this.numTopics, BETA); // TODO
         LDA.addInstances(this.instances);
         LDA.setNumThreads(1);
@@ -164,20 +150,25 @@ public void predictValuesProbs() {
         }
         this.docList = getMaxTopicsByDocs(LDA, this.numTopics);
         System.out.println("Fetched Doc-List");
-        this.topicList = getMaxTermsByTopics(LDA, MAX_TERMS);
+        this.topicList = !topicCreation ? getMaxTermsByTopics(LDA, MAX_TERMS) : null;
         System.out.println("Fetched Topic-List");
     }
 
     public Map<Integer, Double> getValueProbsForID(int id, boolean topicCreation) {
         Map<Integer, Double> terms = null;
         if (id < this.docList.size()) {
-            terms = new LinkedHashMap<Integer, Double>();
             Map<Integer, Double> docVals = this.docList.get(id);
+            if (this.topicList == null) {
+                return docVals;
+            }
+            terms = new LinkedHashMap<Integer, Double>();
+
             for (Map.Entry<Integer, Double> topic : docVals.entrySet()) { // look at each assigned topic
                 Set<Entry<Integer, Double>> entrySet = this.topicList.get(topic.getKey()).entrySet();
                 double topicProb = topic.getValue();
                 for (Map.Entry<Integer, Double> entry : entrySet) { // and its terms
                     if (topicCreation) {
+                        // DEPRECATED
                         if (topicProb > TOPIC_THRESHOLD) {
                             terms.put(entry.getKey(), topicProb);
                             break; // only use first tag as topic-name with the topic probability
@@ -301,7 +292,7 @@ public static List<Map<Integer, Double>> startLdaCreation(BookmarkReader reader,
         if (userBased) {
             userMaps = Utilities.getUserMaps(reader.getBookmarks().subList(0, trainSize));
             userCalc = new MalletCalculator(userMaps, numTopics);
-            userCalc.predictValuesProbs();
+            userCalc.predictValuesProbs(topicCreation);
             //userDenoms = getDenoms(userPredictionValues);
             System.out.println("User-Training finished");
         }
@@ -311,7 +302,7 @@
         if (resBased) {
             resMaps = Utilities.getResMaps(reader.getBookmarks().subList(0, trainSize));
             resCalc = new MalletCalculator(resMaps, numTopics);
-            resCalc.predictValuesProbs();
+            resCalc.predictValuesProbs(topicCreation);
             //resDenoms = getDenoms(resPredictionValues);
             System.out.println("Res-Training finished");
         }
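
Editor's note: taken together, the MalletCalculator changes thread the new topicCreation flag from startLdaCreation() into predictValuesProbs(). When the flag is set, the per-topic term lists are not built (topicList stays null) and getValueProbsForID() returns the document's raw LDA topic probabilities instead of expanding them into term probabilities. The following standalone sketch mirrors only that new branch; the class, method, and variable names here are illustrative and not part of the project's code.

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch (not the project's code) of the branch this commit adds
// to getValueProbsForID(): when predictValuesProbs(true) skipped building the
// per-topic term list (topicList == null), the document's raw topic
// probabilities are returned as the prediction values.
public class TopicCreationSketch {

    static Map<Integer, Double> valueProbsForDoc(List<Map<Integer, Double>> docList,
            List<Map<Integer, Double>> topicList, int id) {
        if (id >= docList.size()) {
            return null; // no LDA result for this document
        }
        Map<Integer, Double> docVals = docList.get(id); // topic -> probability
        if (topicList == null) {
            return docVals; // topic-creation mode: the topics themselves are the output
        }
        // Otherwise the topics would be expanded into term probabilities, as in
        // the unchanged remainder of getValueProbsForID(); omitted in this sketch.
        return new LinkedHashMap<Integer, Double>();
    }

    public static void main(String[] args) {
        Map<Integer, Double> doc0 = new LinkedHashMap<Integer, Double>();
        doc0.put(3, 0.7); // topic 3 with probability 0.7
        doc0.put(8, 0.2); // topic 8 with probability 0.2
        List<Map<Integer, Double>> docList = Arrays.asList(doc0);
        System.out.println(valueProbsForDoc(docList, null, 0)); // prints {3=0.7, 8=0.2}
    }
}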

src/test/Pipeline.java

Lines changed: 10 additions & 6 deletions
@@ -65,9 +65,9 @@ public class Pipeline {
     // set for categorizer/describer split (true is describer, false is categorizer - null for nothing)
     private final static Boolean DESCRIBER = null;
     // placeholder for the topic posfix
-    private final static String TOPIC_NAME = null;
+    private static String TOPIC_NAME = null;
     // placeholder for the used dataset
-    private final static String DATASET = "ml";
+    private final static String DATASET = "cul";
 
     public static void main(String[] args) {
         System.out.println("TagRecommender:\n" + "" +
@@ -90,10 +90,11 @@ public static void main(String[] args) {
         String path = dir + "/" + DATASET + "_sample";
         //getStatistics(path);
         //writeTensorFiles(path, false);
-        //evaluate(dir, path, "wrmf_mml", null, false, true);
-        //createLdaSamples(path, 1, 100, false);
+        //evaluate(dir, path, "wrmf_500_mml", TOPIC_NAME, false, true);
+        //createLdaSamples(path, 1, 500, false);
         //startCfResourceCalculator(dir, path, 1, 20, true, false, false, false, Features.ENTITIES);
         //startCfResourceCalculator(dir, path, 1, 20, false, true, true, false, Features.ENTITIES);
+        //startCfResourceCalculator(dir, path, 1, 20, false, true, false, false, Features.TOPICS);
         //startResourceCIRTTCalculator(dir, path, "", 1, 20, Features.ENTITIES, false, true, false, true);
         //startBaselineCalculatorForResources(dir, path, 1, false);
 
@@ -253,6 +254,9 @@ public static void main(String[] args) {
         } else if (op.equals("item_cfb")) {
             boolean userBased = true, resourceBased = false, allResources = false;
             startCfResourceCalculator(sampleDir, samplePath, sampleCount, 20, userBased, resourceBased, allResources, false, Features.ENTITIES);
+        } else if (op.equals("item_cbt")) {
+            TOPIC_NAME = "lda_500";
+            startCfResourceCalculator(dir, path, 1, 20, false, true, false, false, Features.TOPICS);
         } else if (op.equals("item_zheng")) {
             startZhengResourceCalculator(sampleDir, samplePath, sampleCount);
         } else if (op.equals("item_huang")) {
@@ -501,9 +505,9 @@ private static void getTrainTestSize(String sample) {
 
     // passing the trainSize means that MyMediaLite files will be evaluated
     private static void evaluate(String sampleDir, String sampleName, String prefix, String postfix, boolean calcTags, boolean mymedialite) {
-        getTrainTestSize(sampleName);
+        getTrainTestSize(sampleName + (postfix != null ? "_" + postfix : ""));
         BookmarkReader reader = new BookmarkReader(TRAIN_SIZE, false);
-        reader.readFile(sampleName);
+        reader.readFile(sampleName + (postfix != null ? "_" + postfix : ""));
         if (calcTags) {
             writeMetrics(sampleDir, sampleName, prefix, 1, 10, postfix, reader);
         } else {
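
Editor's note: the evaluate() change makes the postfix part of the input file name. The postfix (for example the TOPIC_NAME "lda_500" set by the new item_cbt op) is now appended to the sample name before getTrainTestSize() and readFile() are called, so a sample named "cul_sample" is read as "cul_sample_lda_500", while a null postfix keeps the old behaviour. A minimal sketch of that name handling; the helper name withPostfix is hypothetical, the expression is taken from the diff.

// Minimal standalone sketch of the file-name handling added to evaluate().
public class PostfixSketch {

    static String withPostfix(String sampleName, String postfix) {
        return sampleName + (postfix != null ? "_" + postfix : "");
    }

    public static void main(String[] args) {
        System.out.println(withPostfix("cul_sample", "lda_500")); // cul_sample_lda_500
        System.out.println(withPostfix("cul_sample", null));      // cul_sample (old behaviour)
    }
}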
