bertopic/_bertopic.py (7 additions, 1 deletion)
@@ -59,7 +59,7 @@
 from bertopic.representation._mmr import mmr
 from bertopic.backend._utils import select_backend
 from bertopic.vectorizers import ClassTfidfTransformer
-from bertopic.representation import BaseRepresentation
+from bertopic.representation import BaseRepresentation, KeyBERTInspired
 from bertopic.dimensionality import BaseDimensionalityReduction
 from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan
 from bertopic._utils import (
@@ -4051,6 +4051,7 @@ def _extract_topics(
             documents,
             fine_tune_representation=fine_tune_representation,
             calculate_aspects=fine_tune_representation,
+            embeddings=embeddings,
         )
         self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)
 
@@ -4311,6 +4312,7 @@ def _extract_words_per_topic(
         c_tf_idf: csr_matrix = None,
         fine_tune_representation: bool = True,
         calculate_aspects: bool = True,
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Based on tf_idf scores per topic, extract the top n words per topic.
 
@@ -4326,6 +4328,8 @@
             fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
                                       If False, the topic representation will remain as the base c-TF-IDF representation.
             calculate_aspects: Whether to calculate additional topic aspects
+            embeddings: Pre-trained document embeddings. These can be used
+                        instead of the sentence-transformer model
 
         Returns:
             topics: The top words per topic
@@ -4361,6 +4365,8 @@
         elif fine_tune_representation and isinstance(self.representation_model, list):
             for tuner in self.representation_model:
                 topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
+        elif fine_tune_representation and isinstance(self.representation_model, KeyBERTInspired):
+            topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics, embeddings)
         elif fine_tune_representation and isinstance(self.representation_model, BaseRepresentation):
             topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
         elif fine_tune_representation and isinstance(self.representation_model, dict):
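Taken together, the `_bertopic.py` changes thread the user-supplied `embeddings` from `_extract_topics` into `_extract_words_per_topic`, where the new `KeyBERTInspired` branch forwards them to the representation model. A minimal usage sketch of the intended effect follows; the corpus and model name are placeholders (a real run needs many more documents), while `BERTopic`, `KeyBERTInspired`, and `fit_transform(docs, embeddings=...)` are existing public API:

```python
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired

docs = ["first example document", "second example document", "third example document"]

# Embed the corpus once, up front
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(docs)

# With this change, the precomputed embeddings passed to fit_transform are also
# forwarded to KeyBERTInspired, so the representative documents are not re-embedded
topic_model = BERTopic(embedding_model=embedding_model, representation_model=KeyBERTInspired())
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)
```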
bertopic/representation/_keybert.py (17 additions, 3 deletions)
@@ -71,6 +71,7 @@ def extract_topics(
         documents: pd.DataFrame,
         c_tf_idf: csr_matrix,
         topics: Mapping[str, List[Tuple[str, float]]],
+        embeddings: np.ndarray = None,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Extract topics.
 
@@ -79,6 +80,8 @@
             documents: All input documents
             c_tf_idf: The topic c-TF-IDF representation
             topics: The candidate topics as calculated with c-TF-IDF
+            embeddings: Pre-trained document embeddings. These can be used
+                        instead of the sentence-transformer model
 
         Returns:
             updated_topics: Updated topic representations
@@ -88,13 +91,19 @@
             c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs
         )
 
+        # If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
+        repr_embeddings = None
+        if embeddings is not None:
+            repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
+
         # We extract the top n words per class
         topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)
 
         # We calculate the similarity between word and document embeddings and create
         # topic embeddings from the representative document embeddings
-        sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices)
-
+        sim_matrix, words = self._extract_embeddings(
+            topic_model, topics, representative_docs, repr_doc_indices, repr_embeddings
+        )
         # Find the best matching words based on the similarity matrix for each topic
         updated_topics = self._extract_top_words(words, topics, sim_matrix)
 
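The new block gathers the precomputed embeddings of the representative documents by flattening the per-topic index lists before handing them to `_extract_embeddings`. A toy illustration of that gathering step (the shapes and index values are hypothetical):

```python
import numpy as np

embeddings = np.random.rand(6, 384)          # hypothetical: one row per document
repr_doc_indices = [[0, 1], [2, 3], [4, 5]]  # hypothetical per-topic index lists

# Same pattern as above: flatten the index lists, then pick the matching rows
repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]

assert len(repr_embeddings) == 6
assert repr_embeddings[0].shape == (384,)
```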
@@ -150,6 +159,7 @@ def _extract_embeddings(
         topics: Mapping[str, List[Tuple[str, float]]],
         representative_docs: List[str],
         repr_doc_indices: List[List[int]],
+        repr_embeddings: np.ndarray = None,
     ) -> Union[np.ndarray, List[str]]:
         """Extract the representative document embeddings and create topic embeddings.
         Then extract word embeddings and calculate the cosine similarity between topic
@@ -162,13 +172,17 @@
             representative_docs: A flat list of representative documents
             repr_doc_indices: The indices of representative documents
                               that belong to each topic
+            repr_embeddings: Precomputed embeddings of the representative documents
 
         Returns:
             sim: The similarity matrix between word and topic embeddings
             vocab: The complete vocabulary of input documents
         """
         # Calculate representative docs embeddings and create topic embeddings
-        repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+        # Only embed the representative documents if no precomputed embeddings were passed in
+        if repr_embeddings is None:
+            repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
+
         topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]
 
         # Calculate word embeddings and extract best matching with updated topic_embeddings
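The `repr_embeddings is None` fallback preserves the old behaviour when no precomputed embeddings are supplied; either way, each topic embedding is still the mean over that topic's slice of representative-document embeddings. A self-contained sketch of that averaging step (array sizes are hypothetical):

```python
import numpy as np

# Hypothetical: 5 representative documents across 2 topics, 384-dim embeddings
repr_embeddings = np.random.rand(5, 384)
repr_doc_indices = [[0, 1, 2], [3, 4]]  # contiguous positions per topic

# Mirrors the list comprehension above: average each topic's slice of
# representative-document embeddings into one topic embedding
topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]

assert len(topic_embeddings) == 2
assert topic_embeddings[0].shape == (384,)
```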