diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 31de626e..f0ed6017 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -4315,11 +4315,11 @@ def _create_topic_vectors( if embeddings is not None and documents is not None: topic_embeddings = [] topics = documents.sort_values("Topic").Topic.unique() + topic_ids = documents["Topic"].values + doc_ids = documents["ID"].values.astype(int) for topic in topics: - indices = documents.loc[documents.Topic == topic, "ID"].values - indices = [int(index) for index in indices] - topic_embedding = np.mean(embeddings[indices], axis=0) - topic_embeddings.append(topic_embedding) + mask = topic_ids == topic + topic_embeddings.append(embeddings[doc_ids[mask]].mean(axis=0)) self.topic_embeddings_ = np.array(topic_embeddings) # Topic embeddings when merging topics