
Commit bc3ddfb

Merge pull request #110 from AnFreTh/main
Releasing version 0.2.0
2 parents dc2d4db + 50da9a3 commit bc3ddfb

File tree

10 files changed: +119, -79 lines changed


README.md

Lines changed: 15 additions & 9 deletions
@@ -73,22 +73,28 @@ You can install STREAM directly from PyPI or from the GitHub repository:
 
 1. **PyPI (Recommended)**:
    ```bash
-   pip install stream_topic
+   pip install stream-topic
    ```
 
 2. **GitHub**:
    ```bash
    pip install git+https://github.com/AnFreTh/STREAM.git
    ```
 
-3. **Download NLTK Resources**:
-   Ensure you have the necessary NLTK resources installed:
-   ```python
-   import nltk
-   nltk.download('stopwords')
-   nltk.download('punkt')
-   nltk.download('wordnet')
-   nltk.download('averaged_perceptron_tagger')
+3. **Install requirements for add-ons**:
+   To use STREAMS visualizations, simply run:
+   ```bash
+   pip install stream-topic[plotting]
+   ```
+
+   For BERTopic, run:
+   ```bash
+   pip install stream-topic[hdbscan]
+   ```
+
+   For DCTE:
+   ```bash
+   pip install stream-topic[dcte]
    ```
 
 # 📦 Available Models
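A note beyond this diff: pip can install several extras in one command, so the add-ons introduced above can be pulled in together. A minimal example, assuming the extra names defined in the extras_require block added to setup.py below:

```bash
# Quoting the requirement avoids shell expansion of the square brackets.
pip install "stream-topic[plotting,bertopic,dcte]"
```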

requirements.txt

Lines changed: 14 additions & 28 deletions
@@ -1,42 +1,28 @@
 # basics
-numpy<=1.26.4
+numpy
 pandas
-pyarrow
-scikit-learn==1.1.0
-scipy==1.10.1
+scikit-learn
+nltk
+datasets
 
 # dl
-lightning==2.3.3
-torch==2.4.0
+lightning
+torch
+sentence_transformers
 
 # nlp
-transformers==4.40.2
-setfit==1.0.3
-gensim==4.2.0
-umap-learn==0.5.6
-wordcloud==1.9.3
+transformers
+gensim
+umap-learn
+
 
 community
-networkx==3.3
+networkx
 python_louvain
 langdetect
-hdbscan==0.8.37
-huggingface_hub==0.23.5
-# nltk
-# datasets==2.20.0
-# sentence-transformers==3.0.1
-
 
-# plotting
-dash
-plotly
-matplotlib
 
 # misc
 loguru
-ipywidgets
-ipykernel<6.22.0
-# tqdm
-# pre-commit
-optuna==3.6.1
-optuna-integration==3.6.0
+optuna
+optuna-integration

setup.py

Lines changed: 33 additions & 3 deletions
@@ -2,8 +2,8 @@
 # -*- coding: utf-8 -*-
 import os
 from pathlib import Path
-
 from setuptools import find_packages, setup
+from setuptools.command.install import install
 
 # Package meta-data.
 NAME = "stream_topic"
@@ -12,7 +12,27 @@
 DOCS = "https://stream.readthedocs.io/en/"
 
 AUTHOR = "Anton Thielmann"
-REQUIRES_PYTHON = ">=3.6, <=3.11"
+REQUIRES_PYTHON = ">=3.6"
+
+
+class PostInstallCommand(install):
+    """Post-installation for downloading NLTK resources."""
+
+    def run(self):
+        install.run(self)
+        try:
+            import nltk
+
+            nltk.download("stopwords")
+            nltk.download("wordnet")
+            nltk.download("punkt_tab")
+            nltk.download("brown")
+            nltk.download("averaged_perceptron_tagger_eng")
+        except ImportError:
+            print(
+                "NLTK not installed. Ensure it is listed in install_requires or installed separately."
+            )
+
 
 # Load the package's verison file and its content.
 ROOT_DIR = Path(__file__).resolve().parent
@@ -30,6 +50,13 @@
     if not line.startswith("#") and not line.startswith("git+")
 ]
 
+extras_require = {
+    "plotting": ["dash", "plotly", "matplotlib", "wordcloud"],
+    "bertopic": ["hdbscan"],
+    "dcte": ["pyarrow", "setfit"],
+}
+
+
 # get long description from readme file
 with open(os.path.join(ROOT_DIR, "README.md")) as f:
     LONG_DESCRIPTION = f.read()
@@ -45,7 +72,7 @@
     author_email=EMAIL,
     python_requires=REQUIRES_PYTHON,
     install_requires=install_reqs,
-    # extras_require=extras_reqs,
+    extras_require=extras_require,
     license="MIT",  # adapt based on your needs
     packages=find_packages(exclude=["examples", "examples.*", "tests", "tests.*"]),
     include_package_data=True,
@@ -65,4 +92,7 @@
     ],
     project_urls={"Documentation": DOCS},
     url=HOMEPAGE,
+    cmdclass={
+        "install": PostInstallCommand,
+    },
 )
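A caveat worth keeping in mind with this approach: a cmdclass install hook such as PostInstallCommand only runs when setup.py is actually executed (for example an install from source or an sdist); installs from a prebuilt wheel typically skip it. If the NLTK resources turn out to be missing at runtime, they can be fetched manually; a minimal sketch mirroring the resource list above:

```python
import nltk

# Same resources PostInstallCommand downloads; run once per environment.
for resource in [
    "stopwords",
    "wordnet",
    "punkt_tab",
    "brown",
    "averaged_perceptron_tagger_eng",
]:
    nltk.download(resource)
```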

stream_topic/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 """Version information."""
 
 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.9"
+__version__ = "0.2.0"

stream_topic/metrics/coherence_metrics.py

Lines changed: 20 additions & 13 deletions
@@ -355,6 +355,26 @@ def __init__(
 
         self.n_words = n_words
 
+    def get_info(self):
+        """
+        Get information about the metric.
+
+        Returns
+        -------
+        dict
+            Dictionary containing model information including metric name,
+            number of top words, number of intruders, embedding model name,
+            metric range and metric description.
+        """
+
+        info = {
+            "metric_name": "Embedding Coherence",
+            "n_words": self.n_words,
+            "description": "Embedding Coherence coherence",
+        }
+
+        return info
+
     def score_per_topic(self, topics):
         """
         Calculates coherence scores for each topic individually based on embedding similarities.
@@ -414,16 +434,3 @@ def score(self, topics):
         """
         res = self.score_per_topic(topics).values()
         return sum(res) / len(res)
-
-
-def _load_default_texts():
-    """
-    Loads default general texts
-
-    Returns
-    -------
-    result : default 20newsgroup texts
-    """
-    dataset = Dataset()
-    dataset.fetch_dataset("20NewsGroup")
-    return dataset.get_corpus()

stream_topic/models/DCTE.py

Lines changed: 19 additions & 5 deletions
@@ -4,12 +4,8 @@
 import pyarrow as pa
 from datasets import Dataset
 from loguru import logger
-from sentence_transformers.losses import CosineSimilarityLoss
-from setfit import SetFitModel, TrainingArguments
-from setfit import Trainer as SetfitTrainer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import OneHotEncoder
-
 from ..commons.check_steps import check_dataset_steps
 from ..preprocessor._tf_idf import c_tf_idf, extract_tfidf_topics
 from ..utils.dataset import TMDataset
@@ -21,6 +17,19 @@
 # logger.add(f"{MODEL_NAME}_{time}.log", backtrace=True, diagnose=True)
 
 
+def import_setfit():
+    try:
+        from setfit import SetFitModel, TrainingArguments
+        from setfit import Trainer as SetfitTrainer
+        from sentence_transformers.losses import CosineSimilarityLoss
+
+        return SetFitModel, TrainingArguments, SetfitTrainer, CosineSimilarityLoss
+    except ImportError as e:
+        raise ImportError(
+            "Setfit is not installed. Please install it by running 'pip install setfit'."
+        ) from e
+
+
 class DCTE(BaseModel):
     """
     A document classification and topic extraction class that utilizes the SetFitModel for
@@ -62,6 +71,9 @@ def __init__(
         )
         self.n_topics = None
 
+        # Lazy import SetFit components
+        SetFitModel, _, _, _ = import_setfit()
+
         self.model = SetFitModel.from_pretrained(f"sentence-transformers/{model}")
         self._status = TrainingStatus.NOT_STARTED
         self.n_topics = None
@@ -122,7 +134,7 @@ def _get_topic_representation(self, predict_df: pd.DataFrame, top_words: int):
             n=top_words,
         )
 
-        one_hot_encoder = OneHotEncoder(sparse=False)
+        one_hot_encoder = OneHotEncoder(sparse_output=False)
         predictions_one_hot = one_hot_encoder.fit_transform(predict_df[["predictions"]])
 
         beta = tfidf
@@ -154,6 +166,8 @@
         dict: A dictionary containing the extracted topics and the topic-word matrix.
         """
 
+        _, TrainingArguments, SetfitTrainer, CosineSimilarityLoss = import_setfit()
+
         assert isinstance(
             dataset, TMDataset
         ), "The dataset must be an instance of TMDataset."
stream_topic/models/KmeansTM.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def fit(
         )
         self.topic_dict = extract_tfidf_topics(tfidf, count, docs_per_topic, n=100)
 
-        one_hot_encoder = OneHotEncoder(sparse=False)
+        one_hot_encoder = OneHotEncoder(sparse_output=False)
         predictions_one_hot = one_hot_encoder.fit_transform(
             self.dataframe[["predictions"]]
         )
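The sparse=False to sparse_output=False change here (and in DCTE.py, bertopicTM.py, and cbc.py) tracks scikit-learn's API: the keyword was renamed to sparse_output in scikit-learn 1.2 and the old spelling was removed in a later release, which the now-unpinned scikit-learn requirement above can resolve to. A minimal standalone sketch of the call as it is now written:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Toy cluster labels standing in for the model's "predictions" column.
predictions = np.array([[0], [2], [1], [2]])

# sparse_output=False returns a dense ndarray rather than a scipy sparse matrix,
# matching how the one-hot matrix is consumed downstream.
encoder = OneHotEncoder(sparse_output=False)
one_hot = encoder.fit_transform(predictions)
print(one_hot.shape)  # (4, 3)
```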

stream_topic/models/abstract_helper_models/base.py

Lines changed: 4 additions & 3 deletions
@@ -3,12 +3,9 @@
 import pickle
 from abc import ABC, abstractmethod
 from enum import Enum
-
-import optuna
 import torch.nn as nn
 import umap.umap_ as umap
 from loguru import logger
-from optuna.integration import PyTorchLightningPruningCallback
 
 
 class BaseModel(ABC):
@@ -340,6 +337,10 @@ def optimize_hyperparameters(
         dict
             Dictionary containing the best parameters and the optimal number of topics.
         """
+        import importlib
+
+        optuna = importlib.import_module("optuna")
+
         assert criterion in [
             "aic",
             "bic",

stream_topic/models/bertopicTM.py

Lines changed: 5 additions & 3 deletions
@@ -1,6 +1,4 @@
 from datetime import datetime
-
-import hdbscan
 import numpy as np
 from loguru import logger
 from sklearn.preprocessing import OneHotEncoder
@@ -121,6 +119,10 @@ def _clustering(self):
         Applies K-Means clustering to the reduced embeddings.
         """
 
+        import importlib
+
+        hdbscan = importlib.import_module("hdbscan")
+
         assert (
             hasattr(self, "reduced_embeddings") and self.reduced_embeddings is not None
         ), "Reduced embeddings must be generated before clustering."
@@ -192,7 +194,7 @@ def fit(self, dataset, n_topics=None):
 
         self.topic_dict = extract_tfidf_topics(tfidf, count, docs_per_topic, n=100)
 
-        one_hot_encoder = OneHotEncoder(sparse=False)
+        one_hot_encoder = OneHotEncoder(sparse_output=False)
         predictions_one_hot = one_hot_encoder.fit_transform(
             self.dataframe[["predictions"]]
         )
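The importlib-based deferral above (like the optuna import in base.py) keeps an optional, heavyweight dependency out of module import time: hdbscan is only needed once _clustering actually runs. A small hypothetical helper sketching the same pattern with a more actionable error message (the lazy_import name and the hint text are illustrative, not part of this commit):

```python
import importlib


def lazy_import(module_name: str, hint: str):
    """Import an optional dependency on first use, failing with an actionable message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(f"'{module_name}' is required for this feature. {hint}") from e


# Example: defer hdbscan until clustering is actually requested.
# hdbscan = lazy_import("hdbscan", "Install the optional clustering add-on first.")
```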

stream_topic/models/cbc.py

Lines changed: 7 additions & 13 deletions
@@ -9,8 +9,7 @@
 
 from ..commons.check_steps import check_dataset_steps
 from ..preprocessor import c_tf_idf, extract_tfidf_topics
-from ..utils.cbc_utils import (DocumentCoherence,
-                               get_top_tfidf_words_per_document)
+from ..utils.cbc_utils import DocumentCoherence, get_top_tfidf_words_per_document
 from ..utils.dataset import TMDataset
 from .abstract_helper_models.base import BaseModel, TrainingStatus
 
@@ -189,12 +188,10 @@ def fit(
             clusters = self.cluster_documents()
 
             num_clusters = len(clusters)
-            print(
-                f"Iteration {iteration}: {num_clusters} clusters formed.")
+            print(f"Iteration {iteration}: {num_clusters} clusters formed.")
 
             # Prepare for the next iteration
-            combined_documents = self.combine_documents(
-                current_documents, clusters)
+            combined_documents = self.combine_documents(current_documents, clusters)
             current_documents = combined_documents
             iteration += 1
 
@@ -247,8 +244,7 @@
         self.labels += 1
 
         # Update the 'predictions' column in the dataframe with -1 where NaN was present
-        self.dataframe["predictions"] = self.dataframe["predictions"].fillna(
-            -1)
+        self.dataframe["predictions"] = self.dataframe["predictions"].fillna(-1)
         self.dataframe["predictions"] += 1
         print("--- replaced NaN values with 0 in topics ---")
         print(
@@ -259,13 +255,11 @@
             {"text": " ".join}
         )
         logger.info("--- Extract topics ---")
-        tfidf, count = c_tf_idf(
-            docs_per_topic["text"].values, m=len(self.dataframe))
-        self.topic_dict = extract_tfidf_topics(
-            tfidf, count, docs_per_topic, n=10)
+        tfidf, count = c_tf_idf(docs_per_topic["text"].values, m=len(self.dataframe))
+        self.topic_dict = extract_tfidf_topics(tfidf, count, docs_per_topic, n=10)
 
         one_hot_encoder = OneHotEncoder(
-            sparse=False
+            sparse_output=False
         )  # Use sparse=False to get a dense array
         predictions_one_hot = one_hot_encoder.fit_transform(
             self.dataframe[["predictions"]]
