|
| 1 | +import yaml |
1 | 2 | import warnings |
2 | 3 | warnings.filterwarnings("ignore", category=FutureWarning) |
3 | 4 | warnings.filterwarnings("ignore", category=UserWarning) |
4 | 5 |
|
# Best-effort: switch off PyYAML's YAMLLoadWarning so it does not clutter the
# output. `_warnings_enabled` is a private PyYAML attribute, so it may be
# missing or have a different shape depending on the installed version; in
# that case there is nothing to disable and the failure is ignored on purpose.
try:
    yaml._warnings_enabled["YAMLLoadWarning"] = False
except (KeyError, AttributeError, TypeError):
    pass
| 10 | + |
5 | 11 | import re |
6 | 12 | import joblib |
7 | 13 | import inspect |
@@ -162,7 +168,6 @@ def __init__(self, |
162 | 168 | self.topics = None |
163 | 169 | self.topic_mapper = None |
164 | 170 | self.topic_sizes = None |
165 | | - self.mapped_topics = None |
166 | 171 | self.merged_topics = None |
167 | 172 | self.topic_embeddings = None |
168 | 173 | self.topic_sim_matrix = None |
@@ -372,10 +377,8 @@ def transform(self, |
372 | 377 | else: |
373 | 378 | probabilities = None |
374 | 379 |
|
375 | | - if self.mapped_topics: |
376 | | - predictions = self._map_predictions(predictions) |
377 | | - probabilities = self._map_probabilities(probabilities) |
378 | | - |
| 380 | + probabilities = self._map_probabilities(probabilities, original_topics=True) |
| 381 | + predictions = self._map_predictions(predictions) |
379 | 382 | return predictions, probabilities |
380 | 383 |
|
381 | 384 | def topics_over_time(self, |
@@ -780,7 +783,7 @@ def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]: |
780 | 783 | return pd.DataFrame(self.topic_sizes.items(), columns=['Topic', 'Count']).sort_values("Count", |
781 | 784 | ascending=False) |
782 | 785 |
|
783 | | - def get_representative_docs(self, topic: int) -> List[str]: |
| 786 | + def get_representative_docs(self, topic: int = None) -> List[str]: |
784 | 787 | """ Extract representative documents per topic |
785 | 788 |
|
786 | 789 | Arguments: |
@@ -1338,13 +1341,12 @@ def _extract_embeddings(self, |
1338 | 1341 |
|
1339 | 1342 | def _map_predictions(self, predictions: List[int]) -> List[int]: |
1340 | 1343 | """ Map predictions to the correct topics if topics were reduced """ |
1341 | | - if self.mapped_topics: |
1342 | | - return [self.mapped_topics[prediction] |
1343 | | - if prediction in self.mapped_topics |
1344 | | - else prediction |
1345 | | - for prediction in predictions] |
1346 | | - else: |
1347 | | - return predictions |
| 1344 | + mappings = self.topic_mapper.get_mappings(original_topics=True) |
| 1345 | + mapped_predictions = [mappings[prediction] |
| 1346 | + if prediction in mappings |
| 1347 | + else -1 |
| 1348 | + for prediction in predictions] |
| 1349 | + return mapped_predictions |
1348 | 1350 |
|
1349 | 1351 | def _reduce_dimensionality(self, |
1350 | 1352 | embeddings: Union[np.ndarray, csr_matrix], |
@@ -1786,9 +1788,6 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame: |
1786 | 1788 | """ |
1787 | 1789 | self._update_topic_size(documents) |
1788 | 1790 |
|
1789 | | - if not self.mapped_topics: |
1790 | | - self.mapped_topics = {topic: topic for topic in set(self.hdbscan_model.labels_)} |
1791 | | - |
1792 | 1791 | # Map topics based on frequency |
1793 | 1792 | df = pd.DataFrame(self.topic_sizes.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False) |
1794 | 1793 | df = df[df.Old_Topic != -1] |
|
0 commit comments