@@ -19,9 +19,7 @@ License, or (at your option) any later version.
1919 */
2020package processing ;
2121
22- import java .io .IOException ;
2322import java .util .ArrayList ;
24- import java .util .Collection ;
2523import java .util .LinkedHashMap ;
2624import java .util .List ;
2725import java .util .Map ;
@@ -30,7 +28,6 @@ License, or (at your option) any later version.
3028import java .util .TreeMap ;
3129import java .util .TreeSet ;
3230import java .util .concurrent .TimeUnit ;
33- import java .util .regex .Pattern ;
3431
3532import com .google .common .base .Stopwatch ;
3633import com .google .common .primitives .Doubles ;
@@ -39,23 +36,12 @@ License, or (at your option) any later version.
3936import common .DoubleMapComparator ;
4037import common .Bookmark ;
4138import common .Utilities ;
42- import cc .mallet .pipe .Array2FeatureVector ;
43- import cc .mallet .pipe .CharSequence2TokenSequence ;
44- import cc .mallet .pipe .CharSequenceArray2TokenSequence ;
45- import cc .mallet .pipe .CharSequenceLowercase ;
46- import cc .mallet .pipe .Pipe ;
47- import cc .mallet .pipe .PrintInputAndTarget ;
48- import cc .mallet .pipe .SerialPipes ;
4939import cc .mallet .pipe .StringList2FeatureSequence ;
50- import cc .mallet .pipe .TokenSequence2FeatureSequence ;
51- import cc .mallet .pipe .TokenSequenceLowercase ;
5240import cc .mallet .topics .ParallelTopicModel ;
5341import cc .mallet .types .Alphabet ;
54- import cc .mallet .types .FeatureSequence ;
5542import cc .mallet .types .IDSorter ;
5643import cc .mallet .types .Instance ;
5744import cc .mallet .types .InstanceList ;
58- import cc .mallet .types .TokenSequence ;
5945import file .PredictionFileWriter ;
6046import file .BookmarkReader ;
6147import file .BookmarkSplitter ;
@@ -64,7 +50,7 @@ public class MalletCalculator {
6450
6551 private final static int MAX_RECOMMENDATIONS = 10 ;
6652 private final static int MAX_TERMS = 100 ;
67- private final static int NUM_THREADS = 10 ;
53+ // private final static int NUM_THREADS = 10;
6854 private final static int NUM_ITERATIONS = 2000 ;
6955 private final static double ALPHA = 0.01 ;
7056 private final static double BETA = 0.01 ;
@@ -105,10 +91,10 @@ private List<Map<Integer, Double>> getMaxTopicsByDocs(ParallelTopicModel LDA, in
10591 double [] topicProbs = LDA .getTopicProbabilities (doc );
10692 //double probSum = 0.0;
10793 for (int topic = 0 ; topic < topicProbs .length && topic < maxTopicsPerDoc ; topic ++) {
108- // if (topicProbs[topic] > 0.01 ) { // TODO
94+ if (topicProbs [topic ] > TOPIC_THRESHOLD ) { // TODO
10995 topicList .put (topic , topicProbs [topic ]);
11096 //probSum += topicProbs[topic];
111- // }
97+ }
11298 }
11399 //System.out.println("Topic Sum: " + probSum);
114100 Map <Integer , Double > sortedTopicList = new TreeMap <Integer , Double >(new DoubleMapComparator (topicList ));
@@ -151,7 +137,7 @@ private List<Map<Integer, Double>> getMaxTermsByTopics(ParallelTopicModel LDA, i
151137 return topicList ;
152138 }
153139
154- public void predictValuesProbs () {
140+ public void predictValuesProbs (boolean topicCreation ) {
155141 ParallelTopicModel LDA = new ParallelTopicModel (this .numTopics , ALPHA * this .numTopics , BETA ); // TODO
156142 LDA .addInstances (this .instances );
157143 LDA .setNumThreads (1 );
@@ -164,20 +150,25 @@ public void predictValuesProbs() {
164150 }
165151 this .docList = getMaxTopicsByDocs (LDA , this .numTopics );
166152 System .out .println ("Fetched Doc-List" );
167- this .topicList = getMaxTermsByTopics (LDA , MAX_TERMS );
153+ this .topicList = ! topicCreation ? getMaxTermsByTopics (LDA , MAX_TERMS ) : null ;
168154 System .out .println ("Fetched Topic-List" );
169155 }
170156
171157 public Map <Integer , Double > getValueProbsForID (int id , boolean topicCreation ) {
172158 Map <Integer , Double > terms = null ;
173159 if (id < this .docList .size ()) {
174- terms = new LinkedHashMap <Integer , Double >();
175160 Map <Integer , Double > docVals = this .docList .get (id );
161+ if (this .topicList == null ) {
162+ return docVals ;
163+ }
164+ terms = new LinkedHashMap <Integer , Double >();
165+
176166 for (Map .Entry <Integer , Double > topic : docVals .entrySet ()) { // look at each assigned topic
177167 Set <Entry <Integer , Double >> entrySet = this .topicList .get (topic .getKey ()).entrySet ();
178168 double topicProb = topic .getValue ();
179169 for (Map .Entry <Integer , Double > entry : entrySet ) { // and its terms
180170 if (topicCreation ) {
171+ // DEPRECATED
181172 if (topicProb > TOPIC_THRESHOLD ) {
182173 terms .put (entry .getKey (), topicProb );
183174 break ; // only use first tag as topic-name with the topic probability
@@ -301,7 +292,7 @@ public static List<Map<Integer, Double>> startLdaCreation(BookmarkReader reader,
301292 if (userBased ) {
302293 userMaps = Utilities .getUserMaps (reader .getBookmarks ().subList (0 , trainSize ));
303294 userCalc = new MalletCalculator (userMaps , numTopics );
304- userCalc .predictValuesProbs ();
295+ userCalc .predictValuesProbs (topicCreation );
305296 //userDenoms = getDenoms(userPredictionValues);
306297 System .out .println ("User-Training finished" );
307298 }
@@ -311,7 +302,7 @@ public static List<Map<Integer, Double>> startLdaCreation(BookmarkReader reader,
311302 if (resBased ) {
312303 resMaps = Utilities .getResMaps (reader .getBookmarks ().subList (0 , trainSize ));
313304 resCalc = new MalletCalculator (resMaps , numTopics );
314- resCalc .predictValuesProbs ();
305+ resCalc .predictValuesProbs (topicCreation );
315306 //resDenoms = getDenoms(resPredictionValues);
316307 System .out .println ("Res-Training finished" );
317308 }
0 commit comments