Skip to content

Commit e5f78ae

Browse files
committed
Misc fixes and code quality improvements
1 parent c3f1d74 commit e5f78ae

File tree

4 files changed

+29
-32
lines changed

4 files changed

+29
-32
lines changed

src/kwx/model.py

Lines changed: 20 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def _order_and_subset_by_coherence(tm, num_topics=10, num_keywords=10):
168168
bow=tm.bow_corpus, eps=0
169169
) # cutoff probability to 0
170170

171-
topics_per_response = [response for response in topic_corpus]
171+
topics_per_response = list(topic_corpus)
172172
flat_topic_coherences = [
173173
item for sublist in topics_per_response for item in sublist
174174
]
@@ -196,7 +196,7 @@ def _order_and_subset_by_coherence(tm, num_topics=10, num_keywords=10):
196196
counts_dict = {
197197
k: v for k, v in counts_dict.items() if k in non_blank_topic_idxs
198198
}
199-
keys_ordered = sorted([k for k in counts_dict])
199+
keys_ordered = sorted(list(counts_dict))
200200

201201
# Map to the range from 0 to the number of non-blank topics.
202202
counts_dict_mapped = {i: counts_dict[k] for i, k in enumerate(keys_ordered)}
@@ -223,10 +223,12 @@ def _order_and_subset_by_coherence(tm, num_topics=10, num_keywords=10):
223223
# Create selection indexes for each topic given its average coherence
224224
# and how many keywords are wanted.
225225
selection_indexes = [
226-
list(range(int(math.floor(num_keywords * a))))
227-
if math.floor(num_keywords * a) > 0
228-
else [0]
229-
for i, a in enumerate(ordered_topic_averages)
226+
(
227+
list(range(int(math.floor(num_keywords * a))))
228+
if math.floor(num_keywords * a) > 0
229+
else [0]
230+
)
231+
for a in ordered_topic_averages
230232
]
231233

232234
total_indexes = sum(len(i) for i in selection_indexes)
@@ -431,9 +433,8 @@ def extract_kws(
431433

432434
assert method in valid_methods, (
433435
"The value for the 'method' argument is invalid. Please choose one of "
434-
+ " ".join(m for m in valid_methods)
435-
+ "."
436-
)
436+
+ " ".join(valid_methods)
437+
) + "."
437438

438439
if method.lower() == "tfidf":
439440
assert corpuses_to_compare is not None, (
@@ -593,13 +594,13 @@ def extract_kws(
593594
new_words_to_ignore = words_to_ignore # initialize so that it can be added to
594595
while more_words_to_ignore:
595596
if first_iteration:
596-
print("The {} keywords are:\n".format(method.upper()))
597-
print(keywords)
597+
print(f"The {method.upper()} keywords are:\n")
598598

599599
else:
600600
print("\n")
601-
print("The new {} keywords are:\n".format(method.upper()))
602-
print(keywords)
601+
print(f"The new {method.upper()} keywords are:\n")
602+
603+
print(keywords)
603604

604605
new_words_to_ignore, words_added = utils.prompt_for_word_removal(
605606
words_to_ignore=new_words_to_ignore
@@ -618,14 +619,12 @@ def extract_kws(
618619
more_words_to_ignore = False
619620

620621
if output_language != input_language:
621-
translated_keywords = utils.translate_output(
622+
return utils.translate_output(
622623
outputs=keywords,
623624
input_language=input_language,
624625
output_language=output_language,
625626
)
626627

627-
return translated_keywords
628-
629628
else:
630629
return keywords
631630

@@ -688,9 +687,8 @@ def gen_files(
688687
-------
689688
A directory or zip file in the current working or save_dir directory.
690689
"""
691-
if isinstance(method, list):
692-
if len(method) == 1:
693-
method = method[0]
690+
if isinstance(method, list) and len(method) == 1:
691+
method = method[0]
694692

695693
if save_dir is None:
696694
save_dir = f"keyword_extraction_{time.strftime('%Y%m%d-%H%M%S')}"
@@ -699,12 +697,12 @@ def gen_files(
699697
if save_dir[-4:] != ".zip":
700698
save_dir += ".zip"
701699

702-
if os.path.exists(os.getcwd() + "/" + save_dir):
703-
os.remove(os.getcwd() + "/" + save_dir)
700+
if os.path.exists(f"{os.getcwd()}/{save_dir}"):
701+
os.remove(f"{os.getcwd()}/{save_dir}")
704702

705703
else:
706704
# Create the directory
707-
save_dir = os.getcwd() + "/" + save_dir
705+
save_dir = f"{os.getcwd()}/{save_dir}"
708706
os.makedirs(save_dir)
709707
if os.path.exists(save_dir):
710708
os.rmdir(save_dir)

src/kwx/topic_model.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,9 +48,7 @@ def __init__(self, num_topics=10, method="lda", bert_model=None):
4848
modeling_methods = ["lda", "bert"]
4949
if method not in modeling_methods:
5050
ValueError(
51-
"The indicated method is invalid. Please choose from {}.".format(
52-
modeling_methods
53-
)
51+
f"The indicated method is invalid. Please choose from {modeling_methods}."
5452
)
5553

5654
self.num_topics = num_topics
@@ -63,7 +61,7 @@ def __init__(self, num_topics=10, method="lda", bert_model=None):
6361
self.vec = {}
6462
self.gamma = 15 # parameter for relative importance of LDA
6563
self.method = method.lower()
66-
self.id = method + "_" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
64+
self.id = f"{method}_" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
6765

6866
def _vectorize(self, text_corpus, method=None, **kwargs):
6967
"""

src/kwx/utils.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -438,11 +438,12 @@ def clean(
438438

439439
except OSError:
440440
try:
441-
os.system("python -m spacy download {}".format(input_language))
441+
os.system(f"python -m spacy download {input_language}")
442442
nlp = spacy.load(input_language)
443443
base_tokens = _lemmatize(
444444
tokens=tokens_remove_unwanted, nlp=nlp, verbose=verbose
445445
)
446+
446447
except OSError:
447448
nlp = None
448449

@@ -525,12 +526,12 @@ def clean(
525526
selected_idxs = list(range(len(text_corpus)))
526527

527528
else:
528-
selected_idxs = [
529-
i
530-
for i in random.choices(
529+
selected_idxs = list(
530+
random.choices(
531531
range(len(text_corpus)), k=int(sample_size * len(text_corpus))
532532
)
533-
]
533+
)
534+
534535
text_corpus = [
535536
_combine_texts_to_str(text_corpus=text_corpus[i]) for i in selected_idxs
536537
]

src/kwx/visuals.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,7 @@ def jaccard_similarity(topic_1, topic_2):
259259
)
260260

261261
if "stability" in metrics:
262-
for j in range(0, len(topic_nums_to_compare) - 1):
262+
for j in range(len(topic_nums_to_compare) - 1):
263263
jaccard_sims = []
264264
for t1, topic1 in enumerate( # pylint: disable=unused-variable
265265
topics_dict[topic_nums_to_compare[j]]

0 commit comments

Comments
 (0)