diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index 92fe0855..61418526 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -59,7 +59,7 @@
 from bertopic.representation._mmr import mmr
 from bertopic.backend._utils import select_backend
 from bertopic.vectorizers import ClassTfidfTransformer
-from bertopic.representation import BaseRepresentation
+from bertopic.representation import BaseRepresentation, KeyBERTInspired
 from bertopic.dimensionality import BaseDimensionalityReduction
 from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan
 from bertopic._utils import (
@@ -4051,6 +4051,7 @@ def _extract_topics(
             documents,
             fine_tune_representation=fine_tune_representation,
             calculate_aspects=fine_tune_representation,
+            embeddings=embeddings,
         )
 
         self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)
@@ -4311,6 +4312,7 @@ def _extract_words_per_topic(
         c_tf_idf: csr_matrix = None,
         fine_tune_representation: bool = True,
         calculate_aspects: bool = True,
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Based on tf_idf scores per topic, extract the top n words per topic.
 
@@ -4326,6 +4328,8 @@ def _extract_words_per_topic(
             fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
                 If False, the topic representation will remain as the base c-TF-IDF representation.
             calculate_aspects: Whether to calculate additional topic aspects
+            embeddings: Pre-trained document embeddings. These can be used
+                instead of an embedding model
 
         Returns:
             topics: The top words per topic
@@ -4361,6 +4365,8 @@ def _extract_words_per_topic(
         elif fine_tune_representation and isinstance(self.representation_model, list):
             for tuner in self.representation_model:
                 topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
+        elif fine_tune_representation and isinstance(self.representation_model, KeyBERTInspired):
+            topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics, embeddings)
         elif fine_tune_representation and isinstance(self.representation_model, BaseRepresentation):
             topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
         elif fine_tune_representation and isinstance(self.representation_model, dict):
diff --git a/bertopic/representation/_keybert.py b/bertopic/representation/_keybert.py
index f91c01cc..10812369 100644
--- a/bertopic/representation/_keybert.py
+++ b/bertopic/representation/_keybert.py
@@ -71,6 +71,7 @@ def extract_topics(
         documents: pd.DataFrame,
         c_tf_idf: csr_matrix,
         topics: Mapping[str, List[Tuple[str, float]]],
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Extract topics.
 
@@ -79,6 +80,8 @@ def extract_topics(
             documents: All input documents
             c_tf_idf: The topic c-TF-IDF representation
             topics: The candidate topics as calculated with c-TF-IDF
+            embeddings: Pre-trained document embeddings. These can be used
+                instead of an embedding model
 
         Returns:
             updated_topics: Updated topic representations
@@ -88,13 +91,19 @@ def extract_topics(
             c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs
         )
 
+        # If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
+        repr_embeddings = None
+        if embeddings is not None:
+            repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
+
         # We extract the top n words per class
         topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)
 
         # We calculate the similarity between word and document embeddings and create
         # topic embeddings from the representative document embeddings
-        sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices)
-
+        sim_matrix, words = self._extract_embeddings(
+            topic_model, topics, representative_docs, repr_doc_indices, repr_embeddings
+        )
         # Find the best matching words based on the similarity matrix for each topic
         updated_topics = self._extract_top_words(words, topics, sim_matrix)
 
@@ -150,6 +159,7 @@ def _extract_embeddings(
         topics: Mapping[str, List[Tuple[str, float]]],
         representative_docs: List[str],
         repr_doc_indices: List[List[int]],
+        repr_embeddings: np.ndarray = None,
     ) -> Union[np.ndarray, List[str]]:
         """Extract the representative document embeddings and create topic embeddings.
         Then extract word embeddings and calculate the cosine similarity between topic
@@ -162,13 +172,16 @@ def _extract_embeddings(
             representative_docs: A flat list of representative documents
             repr_doc_indices: The indices of representative documents
                 that belong to each topic
+            repr_embeddings: Embeddings of respective representative_docs
 
         Returns:
             sim: The similarity matrix between word and topic embeddings
             vocab: The complete vocabulary of input documents
         """
-        # Calculate representative docs embeddings and create topic embeddings
-        repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+        # Calculate representative document embeddings if there are no precomputed embeddings.
+        if repr_embeddings is None:
+            repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+
         topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]
 
         # Calculate word embeddings and extract best matching with updated topic_embeddings
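
Usage sketch (illustrative, not part of the patch): assuming the public BERTopic API, the corpus and the sentence-transformers model name below are placeholders. Embeddings precomputed once and passed to fit_transform are, with this change, forwarded through _extract_topics and _extract_words_per_topic to KeyBERTInspired.extract_topics, so the representative documents are not re-encoded.

# Hypothetical end-to-end example of the behavior this diff enables.
from sentence_transformers import SentenceTransformer

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired

docs = ["first document", "second document", "third document"]  # placeholder corpus

# Precompute document embeddings once; the same model is still used to embed candidate words.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(docs, show_progress_bar=False)

topic_model = BERTopic(
    embedding_model=embedding_model,
    representation_model=KeyBERTInspired(),
)

# The precomputed embeddings are reused for the representative documents
# instead of being recomputed inside KeyBERTInspired.
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)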