23 changes: 20 additions & 3 deletions .gitignore
@@ -25,7 +25,8 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
-
+lightning_logs/
+stream_topic_data/
 # PyInstaller
 # Usually these files are written by a python script from a template
 # before PyInstaller builds the exe, so as to inject date/other infos into it.
@@ -77,7 +78,7 @@ target/
 
 # Jupyter Notebook
 .ipynb_checkpoints
-
+*.ipynb
 # IPython
 profile_default/
 ipython_config.py
@@ -166,7 +167,8 @@ docs/_build/*
 examples/lightning_logs/
 *.ckpt
 # *.yaml
-
+STREAM/lightning_logs/
+/STREAM/lightning_logs/
 # Additional files and directories to ignore
 __pycache__/
 *.pyc
@@ -198,3 +200,18 @@ embeddings/*
 stream_topic.egg-info/
 stream_topic.egg-info/*
 stream_topic/stream_topic_data/*
+checkpoints/
+lightning_logs/
+test1.py
+exp_lda.py
+exp_nmf.py
+Chinese_test.ipynb
+get_metric.ipynb
+metrics.ipynb
+NPMI_PMI.ipynb
+run_parallel.ipynb
+Untitled.ipynb
+run_exp.ipynb
+Untitled1.ipynb
+run_exp2.ipynb
+result/
9 changes: 8 additions & 1 deletion requirements.txt
@@ -39,4 +39,11 @@ ipykernel<6.22.0
 # tqdm
 # pre-commit
 optuna==3.6.1
-optuna-integration==3.6.0
+optuna-integration==3.6.0
+
+# Chinese requirements
+hanlp==2.1.1
+jieba==0.42.1
+OpenCC==1.1.9
+snownlp==0.12.3
+thulac==0.2.2
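
The six new pins bring in Chinese NLP tooling. The diff itself does not show where each package is used, so as a minimal sketch of the segmentation step they enable, assuming jieba handles word tokenization:

import jieba

# Assumption: jieba provides word segmentation for the Chinese pipeline;
# the exact role of each pinned package is not visible in this diff.
text = "今天天气很好"            # "The weather is nice today"
tokens = list(jieba.cut(text))  # jieba.cut yields tokens lazily
print(tokens)                   # e.g. ['今天', '天气', '很', '好']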
7 changes: 5 additions & 2 deletions stream_topic/commons/check_steps.py
@@ -4,7 +4,7 @@
 from .load_steps import load_model_preprocessing_steps
 
 
-def check_dataset_steps(dataset, logger, model_type, preprocessing_steps=None):
+def check_dataset_steps(dataset, logger, model_type, preprocessing_steps=None, language=None):
     """
     Check if the dataset has been preprocessed according to the required steps for the model.
 
@@ -25,7 +25,10 @@ def check_dataset_steps(dataset, logger, model_type, preprocessing_steps=None):
         True if the dataset has been preprocessed according to the required steps, False otherwise.
     """
     if preprocessing_steps is None:
-        preprocessing_steps = load_model_preprocessing_steps(model_type)
+        if language == 'chinese':
+            preprocessing_steps = load_model_preprocessing_steps(model_type, language=language)
+        else:
+            preprocessing_steps = load_model_preprocessing_steps(model_type)
 
     missing_steps = []
 
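
A hedged sketch of how the new language argument changes call sites; dataset, logger, and the "KmeansTM" model_type below are placeholders, not taken from this diff:

from stream_topic.commons.check_steps import check_dataset_steps

# English pipeline: unchanged behaviour, language defaults to None.
check_dataset_steps(dataset, logger, model_type="KmeansTM")

# Chinese pipeline: resolves the Chinese default preprocessing steps instead.
check_dataset_steps(dataset, logger, model_type="KmeansTM", language='chinese')

Note that the branch only fires when preprocessing_steps is None; an explicitly passed preprocessing_steps list still takes precedence over language.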
16 changes: 11 additions & 5 deletions stream_topic/commons/load_steps.py
@@ -2,7 +2,7 @@
 import os
 
 
-def load_model_preprocessing_steps(model_type, filepath=None):
+def load_model_preprocessing_steps(model_type, filepath=None, language=None):
     """
     Load the default preprocessing steps from a JSON file.
 
@@ -21,10 +21,16 @@ def load_model_preprocessing_steps(model_type, filepath=None):
     if filepath is None:
         # Determine the absolute path based on the current file's location
         current_dir = os.path.dirname(__file__)
-        filepath = os.path.join(
-            current_dir, "..", "preprocessor", "config", "default_preprocessing_steps.json"
-        )
-        filepath = os.path.abspath(filepath)
+        if language == 'chinese':
+            filepath = os.path.join(
+                current_dir, "..", "preprocessor", "config", "default_preprocessing_steps_chinese.json"
+            )
+            filepath = os.path.abspath(filepath)
+        else:
+            filepath = os.path.join(
+                current_dir, "..", "preprocessor", "config", "default_preprocessing_steps.json"
+            )
+            filepath = os.path.abspath(filepath)
 
     with open(filepath, "r") as file:
         all_steps = json.load(file)
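
The effect of the new parameter is simply a different JSON config path. A small sketch, again with a placeholder model_type:

from stream_topic.commons.load_steps import load_model_preprocessing_steps

steps = load_model_preprocessing_steps("KmeansTM")                         # default_preprocessing_steps.json
steps_zh = load_model_preprocessing_steps("KmeansTM", language='chinese')  # default_preprocessing_steps_chinese.json

One design remark: any language value other than 'chinese' (including a typo such as 'Chinese') silently falls back to the English defaults, since the branch compares against a single literal.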
16 changes: 9 additions & 7 deletions stream_topic/metrics/TopwordEmbeddings.py
@@ -53,13 +53,15 @@ def __init__(
         create_new_file : bool, optional
             Whether to create a new file to save the embeddings to (default is True).
         """
-        word_embedding_model_name = MetricsConfig.PARAPHRASE_embedder or PARAPHRASE_TRANSFORMER_MODEL
-        if os.path.exists(word_embedding_model_name):
-            print(f"Loading model from local path: {word_embedding_model_name}")
-            word_embedding_model = SentenceTransformer(word_embedding_model_name)
-        else:
-            print(f"Downloading model: {word_embedding_model_name}")
-            word_embedding_model = SentenceTransformer(word_embedding_model_name)
+        if not word_embedding_model:
+            word_embedding_model_name = MetricsConfig.PARAPHRASE_embedder or PARAPHRASE_TRANSFORMER_MODEL
+            if os.path.exists(word_embedding_model_name):
+                print(f"Loading model from local path: {word_embedding_model_name}")
+                word_embedding_model = SentenceTransformer(word_embedding_model_name)
+            else:
+                print(f"Downloading model: {word_embedding_model_name}")
+                word_embedding_model = SentenceTransformer(word_embedding_model_name)
+
 
         self.word_embedding_model = word_embedding_model
         self.cache_to_file = cache_to_file
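
The new guard lets a caller inject a preloaded model and skip the load/download branch entirely. A hedged sketch, with the parameter name inferred from the if not word_embedding_model check and all other constructor arguments omitted:

from sentence_transformers import SentenceTransformer
from stream_topic.metrics.TopwordEmbeddings import TopwordEmbeddings

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")             # any local path or hub name
topword_embedder = TopwordEmbeddings(word_embedding_model=model)   # reuses `model`, no reload

One caveat: if not word_embedding_model tests truthiness rather than identity with None; since SentenceTransformer subclasses torch's Sequential, which defines __len__, an (unusual) empty model would be treated as missing. if word_embedding_model is None: would be the stricter test.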
60 changes: 32 additions & 28 deletions stream_topic/metrics/_helper_funcs.py
@@ -22,13 +22,14 @@ def embed_corpus(dataset,
     Returns the embedding dict
     """
     # Check if embedder is a local path or model name and load accordingly
-    embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
-    if os.path.exists(embedder_name):
-        print(f"Loading model from local path: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
-    else:
-        print(f"Downloading model: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
+    if not embedder:
+        embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
+        if os.path.exists(embedder_name):
+            print(f"Loading model from local path: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
+        else:
+            print(f"Downloading model: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
 
     if emb_filename is None:
         emb_filename = str(dataset)
@@ -65,13 +66,14 @@ def update_corpus_dic_list(
     """
 
     # Check if embedder is a local path or model name and load accordingly
-    embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
-    if os.path.exists(embedder_name):
-        print(f"Loading model from local path: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
-    else:
-        print(f"Downloading model: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
+    if not embedder:
+        embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
+        if os.path.exists(embedder_name):
+            print(f"Loading model from local path: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
+        else:
+            print(f"Downloading model: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
 
     try:
         emb_dic = pickle.load(open(f"{emb_path}{emb_filename}.pickle", "rb"))
@@ -102,13 +104,14 @@ def embed_topic(
     if possible, else use the embedder.
     """
     # Check if embedder is a local path or model name and load accordingly
-    embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
-    if os.path.exists(embedder_name):
-        print(f"Loading model from local path: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
-    else:
-        print(f"Downloading model: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
+    if not embedder:
+        embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
+        if os.path.exists(embedder_name):
+            print(f"Loading model from local path: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
+        else:
+            print(f"Downloading model: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
 
     topic_embeddings = []
     for topic in tqdm(topics_tw):
@@ -137,13 +140,14 @@ def embed_stopwords(
     """
 
     # Check if embedder is a local path or model name and load accordingly
-    embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
-    if os.path.exists(embedder_name):
-        print(f"Loading model from local path: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
-    else:
-        print(f"Downloading model: {embedder_name}")
-        embedder = SentenceTransformer(embedder_name)
+    if not embedder:
+        embedder_name = MetricsConfig.SENTENCE_embedder or SENTENCE_TRANSFORMER_MODEL
+        if os.path.exists(embedder_name):
+            print(f"Loading model from local path: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
+        else:
+            print(f"Downloading model: {embedder_name}")
+            embedder = SentenceTransformer(embedder_name)
 
     sw_dic = {}  # first create dictionary with embedding of every unique word
     stopwords_set = set(stopwords)
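
The same eight-line guard is now repeated verbatim in embed_corpus, update_corpus_dic_list, embed_topic, and embed_stopwords. A hedged refactor sketch, not part of the PR, that would keep the four copies from drifting apart; MetricsConfig and SENTENCE_TRANSFORMER_MODEL are assumed to be the module-level names already imported above:

import os
from sentence_transformers import SentenceTransformer

def _resolve_embedder(embedder, default_name):
    """Return the caller's embedder, or load one from config/default name."""
    if embedder is not None:  # stricter than `if not embedder`
        return embedder
    name = MetricsConfig.SENTENCE_embedder or default_name  # config override wins
    if os.path.exists(name):
        print(f"Loading model from local path: {name}")
    else:
        print(f"Downloading model: {name}")
    return SentenceTransformer(name)

Each function body would then open with embedder = _resolve_embedder(embedder, SENTENCE_TRANSFORMER_MODEL).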