Skip to content

Commit 6669201

Browse files
authored
Speed up _create_topic_vectors by replacing DataFrame .loc with NumPy masking (#2406)
1 parent 32b2ddd commit 6669201

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

bertopic/_bertopic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4315,11 +4315,11 @@ def _create_topic_vectors(
43154315
if embeddings is not None and documents is not None:
43164316
topic_embeddings = []
43174317
topics = documents.sort_values("Topic").Topic.unique()
4318+
topic_ids = documents["Topic"].values
4319+
doc_ids = documents["ID"].values.astype(int)
43184320
for topic in topics:
4319-
indices = documents.loc[documents.Topic == topic, "ID"].values
4320-
indices = [int(index) for index in indices]
4321-
topic_embedding = np.mean(embeddings[indices], axis=0)
4322-
topic_embeddings.append(topic_embedding)
4321+
mask = topic_ids == topic
4322+
topic_embeddings.append(embeddings[doc_ids[mask]].mean(axis=0))
43234323
self.topic_embeddings_ = np.array(topic_embeddings)
43244324

43254325
# Topic embeddings when merging topics

0 commit comments

Comments
 (0)