Skip to content

Commit 32b2ddd

Browse files
Add delete_topics (#2322)
1 parent 458e2e0 commit 32b2ddd

File tree

3 files changed

+198
-4
lines changed

3 files changed

+198
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ docs/_build/
5959

6060
# Jupyter Notebook
6161
.ipynb_checkpoints
62+
notebooks/
6263

6364
# IPython
6465
profile_default/

bertopic/_bertopic.py

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import pandas as pd
2020
import scipy.sparse as sp
21+
from copy import deepcopy
2122

2223
from tqdm import tqdm
2324
from pathlib import Path
@@ -827,7 +828,7 @@ def topics_over_time(
827828
nr_bins: The number of bins you want to create for the timestamps. The left interval will
828829
be chosen as the timestamp. An additional column will be created with the
829830
entire interval.
830-
datetime_format: The datetime format of the timestamps if they are strings, eg %d/%m/%Y.
831+
datetime_format: The datetime format of the timestamps if they are strings, eg "%d/%m/%Y".
831832
Set this to None if you want to have it automatically detect the format.
832833
See strftime documentation for more information on choices:
833834
https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior.
@@ -1778,7 +1779,6 @@ def get_document_info(
17781779
# the topic distributions
17791780
document_info = topic_model.get_document_info(docs, df=df,
17801781
metadata={"Topic_distribution": distributions})
1781-
```
17821782
"""
17831783
check_documents_type(docs)
17841784
if df is not None:
@@ -2168,6 +2168,142 @@ def merge_topics(
21682168
self._save_representative_docs(documents)
21692169
self.probabilities_ = self._map_probabilities(self.probabilities_)
21702170

2171+
def delete_topics(
2172+
self,
2173+
topics_to_delete: List[int],
2174+
) -> None:
2175+
"""Delete topics from the topic model.
2176+
2177+
The deleted topics will be mapped to -1 (outlier topic). Core topic attributes
2178+
like topic embeddings and c-TF-IDF will be automatically updated.
2179+
2180+
Arguments:
2181+
topics_to_delete: List of topics to delete
2182+
"""
2183+
check_is_fitted(self)
2184+
2185+
topics_df = pd.DataFrame({"Topic": self.topics_})
2186+
2187+
# Check if -1 exists in the current topics
2188+
had_outliers = -1 in set(self.topics_)
2189+
2190+
# If adding -1 for the first time, initialize its attributes
2191+
if not had_outliers and any(topic in topics_to_delete for topic in self.topics_):
2192+
# Initialize c-TF-IDF for -1 topic (zeros)
2193+
outlier_row = np.zeros((1, self.c_tf_idf_.shape[1]))
2194+
outlier_row = sp.csr_matrix(outlier_row)
2195+
self.c_tf_idf_ = sp.vstack([outlier_row, self.c_tf_idf_])
2196+
2197+
# Initialize topic embeddings for -1 topic (zeros)
2198+
outlier_embedding = np.zeros((1, self.topic_embeddings_.shape[1]))
2199+
self.topic_embeddings_ = np.vstack([outlier_embedding, self.topic_embeddings_])
2200+
2201+
# Initialize topic representations for -1 topic: ("", 1e-05)
2202+
self.topic_representations_[-1] = [("", 1e-05)]
2203+
2204+
# Initialize representative docs for -1 topic (empty list)
2205+
self.representative_docs_[-1] = []
2206+
2207+
# Initialize representative images for -1 topic if images are being used
2208+
if self.representative_images_ is not None:
2209+
outlier_image = np.zeros((1, self.representative_images_.shape[1]))
2210+
self.representative_images_ = np.vstack([outlier_image, self.representative_images_])
2211+
2212+
# Initialize custom labels for -1 topic if they exist
2213+
if hasattr(self, "custom_labels_") and self.custom_labels_ is not None:
2214+
self.custom_labels_[-1] = ""
2215+
2216+
# Initialize ctfidf model diagonal for -1 topic (ones) if it exists
2217+
if hasattr(self, "ctfidf_model") and self.ctfidf_model is not None:
2218+
n_features = self.ctfidf_model._idf_diag.shape[1]
2219+
outlier_diag = sp.csr_matrix(([1.0], ([0], [0])), shape=(1, n_features))
2220+
self.ctfidf_model._idf_diag = sp.vstack([outlier_diag, self.ctfidf_model._idf_diag])
2221+
2222+
# Initialize topic aspects for -1 topic (empty dict for each aspect) if they exist
2223+
if hasattr(self, "topic_aspects_") and self.topic_aspects_ is not None:
2224+
for aspect in self.topic_aspects_:
2225+
self.topic_aspects_[aspect][-1] = {}
2226+
2227+
# First map deleted topics to -1
2228+
mapping = {topic: -1 if topic in topics_to_delete else topic for topic in set(self.topics_)}
2229+
mapping[-1] = -1
2230+
2231+
# Track mappings and sizes of topics for merging topic embeddings
2232+
mappings = defaultdict(list)
2233+
for key, val in sorted(mapping.items()):
2234+
mappings[val].append(key)
2235+
mappings = {
2236+
topic_to: {
2237+
"topics_from": topics_from,
2238+
"topic_sizes": [self.topic_sizes_[topic] for topic in topics_from],
2239+
}
2240+
for topic_to, topics_from in mappings.items()
2241+
}
2242+
2243+
# remove deleted topics and update attributes
2244+
topics_df.Topic = topics_df.Topic.map(mapping)
2245+
self.topic_mapper_.add_mappings(mapping, topic_model=deepcopy(self))
2246+
topics_df = self._sort_mappings_by_frequency(topics_df)
2247+
self._update_topic_size(topics_df)
2248+
self.probabilities_ = self._map_probabilities(self.probabilities_)
2249+
2250+
final_mapping = self.topic_mapper_.get_mappings(original_topics=False)
2251+
2252+
# Update dictionary-based attributes to remove deleted topics
2253+
# Handle topic_aspects_ if it exists
2254+
if hasattr(self, "topic_aspects_") and self.topic_aspects_ is not None:
2255+
new_aspects = {
2256+
aspect: {
2257+
(final_mapping[old_topic] if old_topic != -1 else -1): content
2258+
for old_topic, content in topics.items()
2259+
if old_topic not in topics_to_delete
2260+
}
2261+
for aspect, topics in self.topic_aspects_.items()
2262+
}
2263+
self.topic_aspects_ = new_aspects
2264+
2265+
# Update custom labels if they exist
2266+
if hasattr(self, "custom_labels_") and self.custom_labels_ is not None:
2267+
new_labels = {
2268+
(final_mapping[old_topic] if old_topic != -1 else -1): label
2269+
for old_topic, label in self.custom_labels_.items()
2270+
if old_topic not in topics_to_delete
2271+
}
2272+
self.custom_labels_ = new_labels
2273+
2274+
# Update topic representations
2275+
new_representations = {
2276+
(final_mapping[old_topic] if old_topic != -1 else -1): content
2277+
for old_topic, content in self.topic_representations_.items()
2278+
if old_topic not in topics_to_delete
2279+
}
2280+
self.topic_representations_ = new_representations
2281+
2282+
# Update representative docs if they exist
2283+
new_representative_docs = {
2284+
(final_mapping[old_topic] if old_topic != -1 else -1): docs
2285+
for old_topic, docs in self.representative_docs_.items()
2286+
if old_topic not in topics_to_delete
2287+
}
2288+
self.representative_docs_ = new_representative_docs
2289+
2290+
# Update representative images if they exist
2291+
if self.representative_images_ is not None:
2292+
# Create a mask for non-deleted topics
2293+
mask = np.array([topic not in topics_to_delete for topic in range(len(self.representative_images_))])
2294+
self.representative_images_ = self.representative_images_[mask] if mask.any() else None
2295+
2296+
# Update array-based attributes using masks to remove deleted topics
2297+
for attr in ["topic_embeddings_", "c_tf_idf_"]:
2298+
matrix = getattr(self, attr)
2299+
mask = np.array([topic not in topics_to_delete for topic in range(matrix.shape[0])])
2300+
setattr(self, attr, matrix[mask])
2301+
2302+
# Update ctfidf model to remove deleted topics if it exists
2303+
if hasattr(self, "ctfidf_model") and self.ctfidf_model is not None:
2304+
mask = np.array([topic not in topics_to_delete for topic in range(self.ctfidf_model._idf_diag.shape[0])])
2305+
self.ctfidf_model._idf_diag = self.ctfidf_model._idf_diag[mask]
2306+
21712307
def reduce_topics(
21722308
self,
21732309
docs: List[str],
@@ -4840,13 +4976,11 @@ def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic):
48404976
).flatten()
48414977
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
48424978
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]
4843-
48444979
if best_cosine_similarity >= topic_model.zeroshot_min_similarity:
48454980
# Using the topic ID from before mapping, get the idx into the zeroshot topic list
48464981
new_topic_id_to_zeroshot_topic_idx[topic_to] = topic_model._topic_id_to_zeroshot_topic_idx[
48474982
zeroshot_topic_ids[best_zeroshot_topic_idx]
48484983
]
4849-
48504984
topic_model._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx
48514985

48524986
def add_new_topics(self, mappings: Mapping[int, int]):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import copy
2+
import pytest
3+
4+
5+
@pytest.mark.parametrize(
6+
"model",
7+
[
8+
("kmeans_pca_topic_model"),
9+
("base_topic_model"),
10+
("custom_topic_model"),
11+
("merged_topic_model"),
12+
("reduced_topic_model"),
13+
("online_topic_model"),
14+
],
15+
)
16+
def test_delete(model, request):
17+
topic_model = copy.deepcopy(request.getfixturevalue(model))
18+
nr_topics = len(set(topic_model.topics_))
19+
length_documents = len(topic_model.topics_)
20+
21+
# First deletion
22+
topics_to_delete = [1, 2]
23+
topic_model.delete_topics(topics_to_delete)
24+
mappings = topic_model.topic_mapper_.get_mappings(list(topic_model.hdbscan_model.labels_))
25+
mapped_labels = [mappings[label] for label in topic_model.hdbscan_model.labels_]
26+
27+
if model == "online_topic_model" or model == "kmeans_pca_topic_model":
28+
assert nr_topics == len(set(topic_model.topics_)) + 1
29+
assert topic_model.get_topic_info().Count.sum() == length_documents
30+
else:
31+
assert nr_topics == len(set(topic_model.topics_)) + 2
32+
assert topic_model.get_topic_info().Count.sum() == length_documents
33+
34+
if model == "online_topic_model":
35+
assert mapped_labels == topic_model.topics_[950:]
36+
else:
37+
assert mapped_labels == topic_model.topics_
38+
39+
# Find two existing topics for second deletion
40+
remaining_topics = sorted(list(set(topic_model.topics_)))
41+
remaining_topics = [t for t in remaining_topics if t != -1] # Exclude outlier topic
42+
topics_to_delete = remaining_topics[:2] # Take first two remaining topics
43+
44+
# Second deletion
45+
topic_model.delete_topics(topics_to_delete)
46+
mappings = topic_model.topic_mapper_.get_mappings(list(topic_model.hdbscan_model.labels_))
47+
mapped_labels = [mappings[label] for label in topic_model.hdbscan_model.labels_]
48+
49+
if model == "online_topic_model" or model == "kmeans_pca_topic_model":
50+
assert nr_topics == len(set(topic_model.topics_)) + 3
51+
assert topic_model.get_topic_info().Count.sum() == length_documents
52+
else:
53+
assert nr_topics == len(set(topic_model.topics_)) + 4
54+
assert topic_model.get_topic_info().Count.sum() == length_documents
55+
56+
if model == "online_topic_model":
57+
assert mapped_labels == topic_model.topics_[950:]
58+
else:
59+
assert mapped_labels == topic_model.topics_

0 commit comments

Comments
 (0)