Skip to content

Commit adf1bdc

Browse files
authored
Enable ruff rule RUF (#2457)
1 parent 92d269a commit adf1bdc

35 files changed

+159
-156
lines changed

bertopic/_bertopic.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ def __init__(
148148
top_n_words: int = 10,
149149
n_gram_range: Tuple[int, int] = (1, 1),
150150
min_topic_size: int = 10,
151-
nr_topics: Union[int, str] = None,
151+
nr_topics: Union[int, str] | None = None,
152152
low_memory: bool = False,
153153
calculate_probabilities: bool = False,
154-
seed_topic_list: List[List[str]] = None,
155-
zeroshot_topic_list: List[str] = None,
154+
seed_topic_list: List[List[str]] | None = None,
155+
zeroshot_topic_list: List[str] | None = None,
156156
zeroshot_min_similarity: float = 0.7,
157157
embedding_model=None,
158158
umap_model=None,
@@ -351,7 +351,7 @@ def fit(
351351
self,
352352
documents: List[str],
353353
embeddings: np.ndarray = None,
354-
images: List[str] = None,
354+
images: List[str] | None = None,
355355
y: Union[List[int], np.ndarray] = None,
356356
):
357357
"""Fit the models on a collection of documents and generate topics.
@@ -396,7 +396,7 @@ def fit_transform(
396396
self,
397397
documents: List[str],
398398
embeddings: np.ndarray = None,
399-
images: List[str] = None,
399+
images: List[str] | None = None,
400400
y: Union[List[int], np.ndarray] = None,
401401
) -> Tuple[List[int], Union[np.ndarray, None]]:
402402
"""Fit the models on a collection of documents, generate topics,
@@ -546,7 +546,7 @@ def transform(
546546
self,
547547
documents: Union[str, List[str]],
548548
embeddings: np.ndarray = None,
549-
images: List[str] = None,
549+
images: List[str] | None = None,
550550
) -> Tuple[List[int], np.ndarray]:
551551
"""After having fit a model, use transform to predict new instances.
552552
@@ -798,9 +798,9 @@ def topics_over_time(
798798
self,
799799
docs: List[str],
800800
timestamps: Union[List[str], List[int]],
801-
topics: List[int] = None,
802-
nr_bins: int = None,
803-
datetime_format: str = None,
801+
topics: List[int] | None = None,
802+
nr_bins: int | None = None,
803+
datetime_format: str | None = None,
804804
evolution_tuning: bool = True,
805805
global_tuning: bool = True,
806806
) -> pd.DataFrame:
@@ -1036,8 +1036,8 @@ def hierarchical_topics(
10361036
self,
10371037
docs: List[str],
10381038
use_ctfidf: bool = True,
1039-
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
1040-
distance_function: Callable[[csr_matrix], csr_matrix] = None,
1039+
linkage_function: Callable[[csr_matrix], np.ndarray] | None = None,
1040+
distance_function: Callable[[csr_matrix], csr_matrix] | None = None,
10411041
) -> pd.DataFrame:
10421042
"""Create a hierarchy of topics.
10431043
@@ -1428,7 +1428,9 @@ def approximate_distribution(
14281428

14291429
return topic_distributions, topic_token_distributions
14301430

1431-
def find_topics(self, search_term: str = None, image: str = None, top_n: int = 5) -> Tuple[List[int], List[float]]:
1431+
def find_topics(
1432+
self, search_term: str | None = None, image: str | None = None, top_n: int = 5
1433+
) -> Tuple[List[int], List[float]]:
14321434
"""Find topics most similar to a search_term.
14331435
14341436
Creates an embedding for a search query and compares that with
@@ -1486,10 +1488,10 @@ def find_topics(self, search_term: str = None, image: str = None, top_n: int = 5
14861488
def update_topics(
14871489
self,
14881490
docs: List[str],
1489-
images: List[str] = None,
1490-
topics: List[int] = None,
1491+
images: List[str] | None = None,
1492+
topics: List[int] | None = None,
14911493
top_n_words: int = 10,
1492-
n_gram_range: Tuple[int, int] = None,
1494+
n_gram_range: Tuple[int, int] | None = None,
14931495
vectorizer_model: CountVectorizer = None,
14941496
ctfidf_model: ClassTfidfTransformer = None,
14951497
representation_model: BaseRepresentation = None,
@@ -1645,7 +1647,7 @@ def get_topic(self, topic: int, full: bool = False) -> Union[Mapping[str, Tuple[
16451647
else:
16461648
return False
16471649

1648-
def get_topic_info(self, topic: int = None) -> pd.DataFrame:
1650+
def get_topic_info(self, topic: int | None = None) -> pd.DataFrame:
16491651
"""Get information about each topic including its ID, frequency, and name.
16501652
16511653
Arguments:
@@ -1671,7 +1673,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame:
16711673
info["CustomName"] = info["Topic"].map(labels)
16721674

16731675
# Main Keywords
1674-
values = {topic: list(list(zip(*values))[0]) for topic, values in self.topic_representations_.items()}
1676+
values = {topic: list(next(zip(*values))) for topic, values in self.topic_representations_.items()}
16751677
info["Representation"] = info["Topic"].map(values)
16761678

16771679
# Extract all topic aspects
@@ -1681,7 +1683,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame:
16811683
if isinstance(list(values.values())[-1][0], tuple) or isinstance(
16821684
list(values.values())[-1][0], list
16831685
):
1684-
values = {topic: list(list(zip(*value))[0]) for topic, value in values.items()}
1686+
values = {topic: list(next(zip(*value))) for topic, value in values.items()}
16851687
elif isinstance(list(values.values())[-1][0], str):
16861688
values = {topic: " ".join(value).strip() for topic, value in values.items()}
16871689
info[aspect] = info["Topic"].map(values)
@@ -1698,7 +1700,7 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame:
16981700

16991701
return info.reset_index(drop=True)
17001702

1701-
def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]:
1703+
def get_topic_freq(self, topic: int | None = None) -> Union[pd.DataFrame, int]:
17021704
"""Return the size of topics (descending order).
17031705
17041706
Arguments:
@@ -1733,7 +1735,7 @@ def get_document_info(
17331735
self,
17341736
docs: List[str],
17351737
df: pd.DataFrame = None,
1736-
metadata: Mapping[str, Any] = None,
1738+
metadata: Mapping[str, Any] | None = None,
17371739
) -> pd.DataFrame:
17381740
"""Get information about the documents on which the topic was trained
17391741
including the documents themselves, their respective topics, the name
@@ -1797,7 +1799,7 @@ def get_document_info(
17971799
document_info = pd.merge(document_info, topic_info, on="Topic", how="left")
17981800

17991801
# Add top n words
1800-
top_n_words = {topic: " - ".join(list(zip(*self.get_topic(topic)))[0]) for topic in set(self.topics_)}
1802+
top_n_words = {topic: " - ".join(next(zip(*self.get_topic(topic)))) for topic in set(self.topics_)}
18011803
document_info["Top_n_words"] = document_info.Topic.map(top_n_words)
18021804

18031805
# Add flat probabilities
@@ -1821,7 +1823,7 @@ def get_document_info(
18211823
document_info[column] = values
18221824
return document_info
18231825

1824-
def get_representative_docs(self, topic: int = None) -> List[str]:
1826+
def get_representative_docs(self, topic: int | None = None) -> List[str]:
18251827
"""Extract the best representing documents per topic.
18261828
18271829
Note:
@@ -1869,7 +1871,7 @@ def get_representative_docs(self, topic: int = None) -> List[str]:
18691871
@staticmethod
18701872
def get_topic_tree(
18711873
hier_topics: pd.DataFrame,
1872-
max_distance: float = None,
1874+
max_distance: float | None = None,
18731875
tight_layout: bool = False,
18741876
) -> str:
18751877
"""Extract the topic tree such that it can be printed.
@@ -2041,9 +2043,9 @@ def generate_topic_labels(
20412043
self,
20422044
nr_words: int = 3,
20432045
topic_prefix: bool = True,
2044-
word_length: int = None,
2046+
word_length: int | None = None,
20452047
separator: str = "_",
2046-
aspect: str = None,
2048+
aspect: str | None = None,
20472049
) -> List[str]:
20482050
"""Get labels for each topic in a user-defined format.
20492051
@@ -2100,7 +2102,7 @@ def merge_topics(
21002102
self,
21012103
docs: List[str],
21022104
topics_to_merge: List[Union[Iterable[int], int]],
2103-
images: List[str] = None,
2105+
images: List[str] | None = None,
21042106
) -> None:
21052107
"""Arguments:
21062108
docs: The documents you used when calling either `fit` or `fit_transform`
@@ -2312,7 +2314,7 @@ def reduce_topics(
23122314
self,
23132315
docs: List[str],
23142316
nr_topics: Union[int, str] = 20,
2315-
images: List[str] = None,
2317+
images: List[str] | None = None,
23162318
use_ctfidf: bool = False,
23172319
) -> None:
23182320
"""Reduce the number of topics to a fixed number of topics
@@ -2379,7 +2381,7 @@ def reduce_outliers(
23792381
self,
23802382
documents: List[str],
23812383
topics: List[int],
2382-
images: List[str] = None,
2384+
images: List[str] | None = None,
23832385
strategy: str = "distributions",
23842386
probabilities: np.ndarray = None,
23852387
threshold: float = 0,
@@ -2538,8 +2540,8 @@ def reduce_outliers(
25382540

25392541
def visualize_topics(
25402542
self,
2541-
topics: List[int] = None,
2542-
top_n_topics: int = None,
2543+
topics: List[int] | None = None,
2544+
top_n_topics: int | None = None,
25432545
use_ctfidf: bool = False,
25442546
custom_labels: bool = False,
25452547
title: str = "<b>Intertopic Distance Map</b>",
@@ -2593,10 +2595,10 @@ def visualize_topics(
25932595
def visualize_documents(
25942596
self,
25952597
docs: List[str],
2596-
topics: List[int] = None,
2598+
topics: List[int] | None = None,
25972599
embeddings: np.ndarray = None,
25982600
reduced_embeddings: np.ndarray = None,
2599-
sample: float = None,
2601+
sample: float | None = None,
26002602
hide_annotations: bool = False,
26012603
hide_document_hover: bool = False,
26022604
custom_labels: bool = False,
@@ -2691,8 +2693,8 @@ def visualize_documents(
26912693

26922694
def visualize_document_datamap(
26932695
self,
2694-
docs: List[str] = None,
2695-
topics: List[int] = None,
2696+
docs: List[str] | None = None,
2697+
topics: List[int] | None = None,
26962698
embeddings: np.ndarray = None,
26972699
reduced_embeddings: np.ndarray = None,
26982700
custom_labels: Union[bool, str] = False,
@@ -2804,10 +2806,10 @@ def visualize_hierarchical_documents(
28042806
self,
28052807
docs: List[str],
28062808
hierarchical_topics: pd.DataFrame,
2807-
topics: List[int] = None,
2809+
topics: List[int] | None = None,
28082810
embeddings: np.ndarray = None,
28092811
reduced_embeddings: np.ndarray = None,
2810-
sample: Union[float, int] = None,
2812+
sample: Union[float, int] | None = None,
28112813
hide_annotations: bool = False,
28122814
hide_document_hover: bool = True,
28132815
nr_levels: int = 10,
@@ -2922,7 +2924,7 @@ def visualize_hierarchical_documents(
29222924

29232925
def visualize_term_rank(
29242926
self,
2925-
topics: List[int] = None,
2927+
topics: List[int] | None = None,
29262928
log_scale: bool = False,
29272929
custom_labels: bool = False,
29282930
title: str = "<b>Term score decline per Topic</b>",
@@ -2986,8 +2988,8 @@ def visualize_term_rank(
29862988
def visualize_topics_over_time(
29872989
self,
29882990
topics_over_time: pd.DataFrame,
2989-
top_n_topics: int = None,
2990-
topics: List[int] = None,
2991+
top_n_topics: int | None = None,
2992+
topics: List[int] | None = None,
29912993
normalize_frequency: bool = False,
29922994
custom_labels: bool = False,
29932995
title: str = "<b>Topics over Time</b>",
@@ -3043,7 +3045,7 @@ def visualize_topics_per_class(
30433045
self,
30443046
topics_per_class: pd.DataFrame,
30453047
top_n_topics: int = 10,
3046-
topics: List[int] = None,
3048+
topics: List[int] | None = None,
30473049
normalize_frequency: bool = False,
30483050
custom_labels: bool = False,
30493051
title: str = "<b>Topics per Class</b>",
@@ -3199,16 +3201,16 @@ def visualize_approximate_distribution(
31993201
def visualize_hierarchy(
32003202
self,
32013203
orientation: str = "left",
3202-
topics: List[int] = None,
3203-
top_n_topics: int = None,
3204+
topics: List[int] | None = None,
3205+
top_n_topics: int | None = None,
32043206
use_ctfidf: bool = True,
32053207
custom_labels: bool = False,
32063208
title: str = "<b>Hierarchical Clustering</b>",
32073209
width: int = 1000,
32083210
height: int = 600,
32093211
hierarchical_topics: pd.DataFrame = None,
3210-
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
3211-
distance_function: Callable[[csr_matrix], csr_matrix] = None,
3212+
linkage_function: Callable[[csr_matrix], np.ndarray] | None = None,
3213+
distance_function: Callable[[csr_matrix], csr_matrix] | None = None,
32123214
color_threshold: int = 1,
32133215
) -> "go.Figure":
32143216
"""Visualize a hierarchical structure of the topics.
@@ -3298,9 +3300,9 @@ def visualize_hierarchy(
32983300

32993301
def visualize_heatmap(
33003302
self,
3301-
topics: List[int] = None,
3302-
top_n_topics: int = None,
3303-
n_clusters: int = None,
3303+
topics: List[int] | None = None,
3304+
top_n_topics: int | None = None,
3305+
n_clusters: int | None = None,
33043306
use_ctfidf: bool = False,
33053307
custom_labels: bool = False,
33063308
title: str = "<b>Similarity Matrix</b>",
@@ -3358,7 +3360,7 @@ def visualize_heatmap(
33583360

33593361
def visualize_barchart(
33603362
self,
3361-
topics: List[int] = None,
3363+
topics: List[int] | None = None,
33623364
top_n_topics: int = 8,
33633365
n_words: int = 5,
33643366
custom_labels: bool = False,
@@ -3750,8 +3752,8 @@ def push_to_hf_hub(
37503752
self,
37513753
repo_id: str,
37523754
commit_message: str = "Add BERTopic model",
3753-
token: str = None,
3754-
revision: str = None,
3755+
token: str | None = None,
3756+
revision: str | None = None,
37553757
private: bool = False,
37563758
create_pr: bool = False,
37573759
model_card: bool = True,
@@ -3842,9 +3844,9 @@ def get_params(self, deep: bool = False) -> Mapping[str, Any]:
38423844
def _extract_embeddings(
38433845
self,
38443846
documents: Union[List[str], str],
3845-
images: List[str] = None,
3847+
images: List[str] | None = None,
38463848
method: str = "document",
3847-
verbose: bool = None,
3849+
verbose: bool | None = None,
38483850
) -> np.ndarray:
38493851
"""Extract sentence/document embeddings through pre-trained embeddings
38503852
For an overview of pre-trained models: https://www.sbert.net/docs/pretrained_models.html.
@@ -4237,7 +4239,7 @@ def _extract_representative_docs(
42374239
topics: Mapping[str, List[Tuple[str, float]]],
42384240
nr_samples: int = 500,
42394241
nr_repr_docs: int = 5,
4240-
diversity: float = None,
4242+
diversity: float | None = None,
42414243
) -> Union[List[str], List[List[int]]]:
42424244
"""Approximate most representative documents per topic by sampling
42434245
a subset of the documents in each topic and calculating which are
@@ -4554,7 +4556,7 @@ def _extract_words_per_topic(
45544556
aspects = aspect_model.extract_topics(self, documents, c_tf_idf, aspects)
45554557
else:
45564558
raise TypeError(
4557-
f"unsupported type {type(aspect_model).__name__} for representation_model[{repr(aspect)}]"
4559+
f"unsupported type {type(aspect_model).__name__} for representation_model[{aspect!r}]"
45584560
)
45594561
self.topic_aspects_[aspect] = aspects
45604562

@@ -5017,9 +5019,9 @@ def _create_model_from_files(
50175019
topics: Mapping[str, Any],
50185020
params: Mapping[str, Any],
50195021
tensors: Mapping[str, np.array],
5020-
ctfidf_tensors: Mapping[str, Any] = None,
5021-
ctfidf_config: Mapping[str, Any] = None,
5022-
images: Mapping[int, Any] = None,
5022+
ctfidf_tensors: Mapping[str, Any] | None = None,
5023+
ctfidf_config: Mapping[str, Any] | None = None,
5024+
images: Mapping[int, Any] | None = None,
50235025
warn_no_backend: bool = True,
50245026
):
50255027
"""Create a BERTopic model from a variety of inputs.

bertopic/_save_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def push_to_hf_hub(
107107
model,
108108
repo_id: str,
109109
commit_message: str = "Add BERTopic model",
110-
token: str = None,
111-
revision: str = None,
110+
token: str | None = None,
111+
revision: str | None = None,
112112
private: bool = False,
113113
create_pr: bool = False,
114114
model_card: bool = True,
@@ -286,7 +286,7 @@ def generate_readme(model, repo_id: str):
286286
nr_documents = ""
287287

288288
# Topic information
289-
topic_keywords = [" - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics]
289+
topic_keywords = [" - ".join(next(zip(*model.get_topic(topic)))[:5]) for topic in topics]
290290
topic_freq = [model.get_topic_freq(topic) for topic in topics]
291291
topic_labels = model.custom_labels_ if model.custom_labels_ else [model.topic_labels_[topic] for topic in topics]
292292
topics = [

0 commit comments

Comments (0)