
Commit 7cc69e9

Merge pull request #1232 from PyThaiNLP/copilot/fix-inconsistent-type-hints
Fix type hint inconsistencies and improve type precision
2 parents 7f4a62f + ba221c3 commit 7cc69e9

42 files changed (+259 −212 lines)

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -61,7 +61,10 @@ classifiers = [
 ]
 
 # Core dependencies
-dependencies = ["tzdata; sys_platform == 'win32'"]
+dependencies = [
+    "importlib_resources; python_version < '3.11'",
+    "tzdata; sys_platform == 'win32'",
+]
 
 [project.optional-dependencies]
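The new importlib_resources entry is a conditional dependency: its environment marker restricts it to interpreters older than Python 3.11. Below is a minimal sketch of the import pattern such a backport typically supports; the branch point simply mirrors the marker above, and the resource lookup in the comment is illustrative, not taken from this commit.

import sys

if sys.version_info >= (3, 11):
    from importlib import resources  # standard library is sufficient here
else:
    import importlib_resources as resources  # backport declared in pyproject.toml

# e.g. resources.files("pythainlp") would then yield a Traversable for packaged data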

pythainlp/augment/lm/wangchanberta.py

Lines changed: 6 additions & 7 deletions
@@ -32,7 +32,7 @@ def __init__(self):
         self.MASK_TOKEN = self.tokenizer.mask_token
 
     def generate(self, sentence: str, num_replace_tokens: int = 3):
-        self.sent2 = []
+        sent2: list[str] = []
         self.input_text = sentence
         sent = [
             i for i in self.tokenizer.tokenize(self.input_text) if i != "▁"
@@ -42,13 +42,13 @@ def generate(self, sentence: str, num_replace_tokens: int = 3):
         masked_text = self.input_text
         for i in range(num_replace_tokens):
            masked_text = masked_text + self.MASK_TOKEN
-            self.sent2 += [
+            sent2 += [
                str(j["sequence"]).replace("<s> ", "").replace("</s>", "")
                for j in self.fill_mask(masked_text)
-                if j["sequence"] not in self.sent2
+                if j["sequence"] not in sent2
            ]
            masked_text = self.input_text
-        return self.sent2
+        return sent2
 
     def augment(self, sentence: str, num_replace_tokens: int = 3) -> list[str]:
         """Text augmentation from WangchanBERTa
@@ -73,6 +73,5 @@ def augment(self, sentence: str, num_replace_tokens: int = 3) -> list[str]:
         'ช้างมีทั้งหมด 50 ตัว บนนั้น',
         'ช้างมีทั้งหมด 50 ตัว บนหัว']
         """
-        self.sent2 = []
-        self.sent2 = self.generate(sentence, num_replace_tokens)
-        return self.sent2
+        sent2 = self.generate(sentence, num_replace_tokens)
+        return sent2
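This hunk replaces the mutable instance attribute self.sent2 with a per-call local, which is what makes the precise list[str] annotation possible. A self-contained sketch of why that matters follows; the class names are illustrative stand-ins, not the PyThaiNLP augmenter.

class SharedState:
    def __init__(self) -> None:
        self.results = []                 # survives across calls

    def generate(self, text: str) -> list:
        self.results += [text.upper()]
        return self.results


class LocalState:
    def generate(self, text: str) -> list[str]:
        results: list[str] = []           # fresh, precisely typed list every call
        results += [text.upper()]
        return results


shared, local = SharedState(), LocalState()
assert shared.generate("a") == ["A"]
assert shared.generate("b") == ["A", "B"]   # earlier output leaks into later calls
assert local.generate("a") == ["A"]
assert local.generate("b") == ["B"]         # each call is independent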

pythainlp/augment/word2vec/bpemb_wv.py

Lines changed: 2 additions & 2 deletions
@@ -35,15 +35,15 @@ def load_w2v(self):
 
     def augment(
         self, sentence: str, n_sent: int = 1, p: float = 0.7
-    ) -> list[tuple[str]]:
+    ) -> list[str]:
         """Text Augment using word2vec from BPEmb
 
         :param str sentence: Thai sentence
         :param int n_sent: number of sentence
         :param float p: probability of word
 
         :return: list of synonyms
-        :rtype: List[str]
+        :rtype: list[str]
         :Example:
         ::
 
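The return annotation changes from list[tuple[str]] to list[str], matching what the method actually produces. A toy illustration of the mismatch a checker such as mypy would report (these functions are invented examples, not the real BPEmbAug method):

def augment_before() -> list[tuple[str]]:
    return ["ผม", "ฉัน"]   # mypy: list[str] is incompatible with list[tuple[str]]

def augment_after() -> list[str]:
    return ["ผม", "ฉัน"]   # annotation now matches the runtime value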

pythainlp/augment/word2vec/core.py

Lines changed: 11 additions & 10 deletions
@@ -4,14 +4,15 @@
 from __future__ import annotations
 
 import itertools
+from typing import Callable
 
 
 class Word2VecAug:
     def __init__(
-        self, model: str, tokenize: object, type: str = "file"
+        self, model: str, tokenize: Callable[[str], list[str]], type: str = "file"
     ) -> None:
         """:param str model: path of model
-        :param object tokenize: tokenize function
+        :param Callable[[str], list[str]] tokenize: tokenize function
         :param str type: model type (file, binary)
         """
         import gensim.models.keyedvectors as word2vec
@@ -27,10 +28,10 @@ def __init__(
         self.model = model
         self.dict_wv = list(self.model.key_to_index.keys())
 
-    def modify_sent(self, sent: str, p: float = 0.7) -> list[list[str]]:
-        """:param str sent: text of sentence
+    def modify_sent(self, sent: list[str], p: float = 0.7) -> list[list[str]]:
+        """:param list[str] sent: list of tokens
         :param float p: probability
-        :rtype: List[List[str]]
+        :rtype: list[list[str]]
         """
         list_sent_new = []
         for i in sent:
@@ -46,17 +47,17 @@ def modify_sent(self, sent: str, p: float = 0.7) -> list[list[str]]:
 
     def augment(
         self, sentence: str, n_sent: int = 1, p: float = 0.7
-    ) -> list[tuple[str]]:
+    ) -> list[tuple[str, ...]]:
         """:param str sentence: text of sentence
         :param int n_sent: maximum number of synonymous sentences
         :param int p: probability
 
         :return: list of synonyms
-        :rtype: List[Tuple[str]]
+        :rtype: list[tuple[str, ...]]
         """
-        self.sentence = self.tokenizer(sentence)
-        self.list_synonym = self.modify_sent(self.sentence, p=p)
+        _sentence = self.tokenizer(sentence)
+        _list_synonym = self.modify_sent(_sentence, p=p)
         new_sentences = []
-        for x in list(itertools.product(*self.list_synonym))[0:n_sent]:
+        for x in list(itertools.product(*_list_synonym))[0:n_sent]:
            new_sentences.append(x)
        return new_sentences
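Two of the sharpened hints here recur throughout the commit: Callable[[str], list[str]] instead of object for tokenizer arguments, and tuple[str, ...] for variable-length tuples (tuple[str] means exactly one element). A short illustration with a stand-in tokenizer, not the gensim-backed class above:

from typing import Callable

def whitespace_tokenize(text: str) -> list[str]:
    return text.split()

tokenize: Callable[[str], list[str]] = whitespace_tokenize

one: tuple[str] = ("ผม",)                      # fixed length: exactly one string
many: tuple[str, ...] = ("ผม", "กิน", "ข้าว")   # any number of strings

print(tokenize("ผม กิน ข้าว"), one, many)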

pythainlp/benchmarks/word_tokenization.py

Lines changed: 17 additions & 17 deletions
@@ -148,35 +148,35 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict:
     :return: metrics at character- and word-level and indicators of correctly tokenized words
     :rtype: dict[str, Union[float, str]]
     """
-    ref_sample = _binary_representation(ref_sample)
-    sample = _binary_representation(raw_sample)
+    ref_sample_arr = _binary_representation(ref_sample)
+    sample_arr = _binary_representation(raw_sample)
 
     # Compute character-level statistics
-    c_pos_pred, c_neg_pred = np.argwhere(sample == 1), np.argwhere(sample == 0)
+    c_pos_pred, c_neg_pred = np.argwhere(sample_arr == 1), np.argwhere(sample_arr == 0)
 
-    c_pos_pred = c_pos_pred[c_pos_pred < ref_sample.shape[0]]
-    c_neg_pred = c_neg_pred[c_neg_pred < ref_sample.shape[0]]
+    c_pos_pred = c_pos_pred[c_pos_pred < ref_sample_arr.shape[0]]
+    c_neg_pred = c_neg_pred[c_neg_pred < ref_sample_arr.shape[0]]
 
-    c_tp = np.sum(ref_sample[c_pos_pred] == 1)
-    c_fp = np.sum(ref_sample[c_pos_pred] == 0)
+    c_tp = np.sum(ref_sample_arr[c_pos_pred] == 1)
+    c_fp = np.sum(ref_sample_arr[c_pos_pred] == 0)
 
-    c_tn = np.sum(ref_sample[c_neg_pred] == 0)
-    c_fn = np.sum(ref_sample[c_neg_pred] == 1)
+    c_tn = np.sum(ref_sample_arr[c_neg_pred] == 0)
+    c_fn = np.sum(ref_sample_arr[c_neg_pred] == 1)
 
     # Compute word-level statistics
 
     # Find correctly tokenized words in the reference sample
-    word_boundaries = _find_word_boundaries(ref_sample)
+    word_boundaries = _find_word_boundaries(ref_sample_arr)
 
     # Find correctly tokenized words in the sample
-    ss_boundaries = _find_word_boundaries(sample)
+    ss_boundaries = _find_word_boundaries(sample_arr)
     tokenization_indicators = _find_words_correctly_tokenised(
         word_boundaries, ss_boundaries
     )
 
     correctly_tokenised_words = np.sum(tokenization_indicators)
 
-    tokenization_indicators = list(map(str, tokenization_indicators))
+    tokenization_indicators_str = list(map(str, tokenization_indicators))
 
     return {
         "char_level": {
@@ -187,11 +187,11 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict:
         },
         "word_level": {
             "correctly_tokenised_words": correctly_tokenised_words,
-            "total_words_in_sample": np.sum(sample),
-            "total_words_in_ref_sample": np.sum(ref_sample),
+            "total_words_in_sample": np.sum(sample_arr),
+            "total_words_in_ref_sample": np.sum(ref_sample_arr),
         },
         "global": {
-            "tokenisation_indicators": "".join(tokenization_indicators)
+            "tokenisation_indicators": "".join(tokenization_indicators_str)
         },
     }
 
@@ -246,14 +246,14 @@ def _find_word_boundaries(bin_reps) -> list:
 def _find_words_correctly_tokenised(
     ref_boundaries: list[tuple[int, int]],
     predicted_boundaries: list[tuple[int, int]],
-) -> tuple[int]:
+) -> tuple[int, ...]:
     """Find whether each word is correctly tokenized.
 
     :param list[tuple(int, int)] ref_boundaries: word boundaries of reference tokenization
     :param list[tuple(int, int)] predicted_boundaries: word boundareies of predicted tokenization
 
     :return: binary sequence where 1 indicates the corresponding word is tokenized correctly
-    :rtype: tuple[int]
+    :rtype: tuple[int, ...]
     """
     ref_b = dict(zip(ref_boundaries, [1] * len(ref_boundaries)))
 
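Here the fix is essentially a rename: the str parameters ref_sample and raw_sample were being re-bound to NumPy arrays, so one variable carried two unrelated types. A stripped-down sketch of the pattern; the _to_binary stand-in below is simplified and is not the benchmark's real _binary_representation logic.

import numpy as np

def _to_binary(sample: str) -> np.ndarray:
    # 1 marks a word-boundary character, 0 anything else (simplified stand-in)
    return np.array([1 if ch == "|" else 0 for ch in sample])

def count_boundaries(ref_sample: str) -> int:
    ref_sample_arr = _to_binary(ref_sample)   # new name: the str and the ndarray stay distinct
    return int(np.sum(ref_sample_arr))

print(count_boundaries("กา|กิน|ข้าว"))  # 2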

pythainlp/corpus/common.py

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@
 _THAI_ORST_WORDS: frozenset[str] = frozenset()
 
 _THAI_DICT: dict[str, list[str]] = {}
-_THAI_WSD_DICT: dict[str, list[str]] = {}
-_THAI_SYNONYMS: dict[str, list[str]] = {}
+_THAI_WSD_DICT: dict[str, Union[list[str], list[list[str]]]] = {}
+_THAI_SYNONYMS: dict[str, Union[list[str], list[list[str]]]] = {}
 
 
 def countries() -> frozenset[str]:
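The widened Union value type allows a key to map either to a flat list of strings or to a list of string lists. A hedged illustration of a value that satisfies the new annotation (the entries are invented, not actual corpus data):

from typing import Union

_EXAMPLE_WSD_DICT: dict[str, Union[list[str], list[list[str]]]] = {
    "word": ["เงิน", "ทอง"],                          # flat list of strings
    "meaning": [["สื่อกลางแลกเปลี่ยน"], ["โลหะมีค่า"]],   # list of lists of strings
}
print(_EXAMPLE_WSD_DICT)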

pythainlp/corpus/core.py

Lines changed: 22 additions & 7 deletions
@@ -40,7 +40,7 @@ def json(self) -> dict:
         try:
             return json.loads(self._content.decode("utf-8"))
         except (json.JSONDecodeError, UnicodeDecodeError) as err:
-            raise ValueError(f"Failed to parse JSON response: {err}")
+            raise ValueError(f"Failed to parse JSON response: {err}") from err
 
 
 def get_corpus_db(url: str) -> Optional[_ResponseWrapper]:
@@ -298,9 +298,17 @@ def get_corpus_path(name: str, version: str = "", force: bool = False) -> Option
     if corpus_db_detail and corpus_db_detail.get("filename"):
         # corpus is in the local catalog, get full path to the file
         if corpus_db_detail.get("is_folder"):
-            path = get_full_data_path(corpus_db_detail.get("foldername"))
+            foldername = corpus_db_detail.get("foldername")
+            if foldername:
+                path = get_full_data_path(foldername)
+            else:
+                return None
         else:
-            path = get_full_data_path(corpus_db_detail.get("filename"))
+            filename = corpus_db_detail.get("filename")
+            if filename:
+                path = get_full_data_path(filename)
+            else:
+                return None
         # check if the corpus file actually exists, download it if not
         if not os.path.exists(path):
             download(name, version=version, force=force)
@@ -736,10 +744,14 @@ def remove(name: str) -> bool:
         if data[0].get("is_folder"):
             import shutil
 
-            os.remove(get_full_data_path(data[0].get("filename")))
-            shutil.rmtree(path, ignore_errors=True)
+            filename = data[0].get("filename")
+            if filename:
+                os.remove(get_full_data_path(filename))
+            if path:
+                shutil.rmtree(path, ignore_errors=True)
         else:
-            os.remove(path)
+            if path:
+                os.remove(path)
         for i, corpus in db["_default"].copy().items():
             if corpus["name"] == name:
                 del db["_default"][i]
@@ -751,7 +763,10 @@ def remove(name: str) -> bool:
 
 
 def get_path_folder_corpus(name: str, version: str, *path: str) -> str:
-    return os.path.join(get_corpus_path(name, version), *path)
+    corpus_path = get_corpus_path(name, version)
+    if corpus_path is None:
+        raise ValueError(f"Corpus path not found for {name} version {version}")
+    return os.path.join(corpus_path, *path)
 
 
 def make_safe_directory_name(name: str) -> str:
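Several of these hunks apply the same Optional-narrowing pattern: check a value that may be None before handing it to an API that requires str. A minimal, self-contained sketch follows; get_corpus_path here is a stub, not the real downloader-backed function.

import os
from typing import Optional

def get_corpus_path(name: str) -> Optional[str]:
    return None if name == "missing" else os.path.join("/data", name)

def get_path_folder_corpus(name: str, *path: str) -> str:
    corpus_path = get_corpus_path(name)
    if corpus_path is None:                     # narrows Optional[str] to str
        raise ValueError(f"Corpus path not found for {name}")
    return os.path.join(corpus_path, *path)

print(get_path_folder_corpus("words_th", "words_th.txt"))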

pythainlp/corpus/util.py

Lines changed: 5 additions & 4 deletions
@@ -43,12 +43,13 @@ def find_badwords(
     :return: words that are considered to make `tokenize` perform badly
     :rtype: Set[str]
     """
-    right = Counter()
-    wrong = Counter()
+    right: Counter[str] = Counter()
+    wrong: Counter[str] = Counter()
 
     for train_words in training_data:
-        train_set = set(index_pairs(train_words))
-        test_words = tokenize("".join(train_words))
+        train_words_list = list(train_words)
+        train_set = set(index_pairs(train_words_list))
+        test_words = tokenize("".join(train_words_list))
         test_pairs = index_pairs(test_words)
         for w, p in zip(test_words, test_pairs):
             if p in train_set:
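Two changes meet here: Counter gets an explicit element type, and the incoming iterable is materialized into a list so it can be traversed more than once. A hedged sketch with toy data (find_badwords itself does more work than this):

from collections import Counter

right: Counter[str] = Counter()                # element type is now explicit

train_words = (w for w in ["กิน", "ข้าว"])      # generators are single-pass
train_words_list = list(train_words)           # materialize before reusing

right.update(train_words_list)                 # first use
joined = "".join(train_words_list)             # second use; a second pass over the raw generator would see nothing
print(right, joined)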

pythainlp/lm/text_util.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 
 def calculate_ngram_counts(
     list_words: list[str], n_min: int = 2, n_max: int = 4
-) -> dict[tuple[str], int]:
+) -> dict[tuple[str, ...], int]:
     """Calculates the counts of n-grams in the list words for the specified range.
 
     :param List[str] list_words: List of string
@@ -20,7 +20,7 @@ def calculate_ngram_counts(
     if not list_words:
         return {}
 
-    ngram_counts = {}
+    ngram_counts: dict[tuple[str, ...], int] = {}
 
     for n in range(n_min, n_max + 1):
         for i in range(len(list_words) - n + 1):
@@ -51,7 +51,7 @@ def remove_repeated_ngrams(string_list: list[str], n: int = 2) -> list[str]:
 
     unique_ngrams = set()
 
-    output_list = []
+    output_list: list[str] = []
 
     for i in range(len(string_list)):
         if i + n <= len(string_list):
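dict[tuple[str], int] would only admit one-word keys, while dict[tuple[str, ...], int] admits keys of any n-gram length. A toy reconstruction of that counting shape, simplified relative to the real calculate_ngram_counts:

ngram_counts: dict[tuple[str, ...], int] = {}
tokens = ["ผม", "กิน", "ข้าว"]

for n in (2, 3):                                # bigrams and trigrams
    for i in range(len(tokens) - n + 1):
        key = tuple(tokens[i:i + n])            # key length varies with n
        ngram_counts[key] = ngram_counts.get(key, 0) + 1

print(ngram_counts)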

pythainlp/parse/core.py

Lines changed: 9 additions & 9 deletions
@@ -3,9 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
-_tagger = None
+_tagger: Optional[Any] = None
 _tagger_name = ""
 
 
@@ -100,22 +100,22 @@ def dependency_parsing(
     if engine == "esupar":
         from pythainlp.parse.esupar_engine import Parse
 
-        _tagger = Parse(model=model)
+        _tagger = Parse(model=model if model else "th")
     elif engine == "transformers_ud":
-        from pythainlp.parse.transformers_ud import Parse
+        from pythainlp.parse.transformers_ud import Parse  # type: ignore[assignment] # noqa: I001
 
-        _tagger = Parse(model=model)
+        _tagger = Parse(model=model if model else "KoichiYasuoka/deberta-base-thai-ud-head")
     elif engine == "spacy_thai":
-        from pythainlp.parse.spacy_thai_engine import Parse
+        from pythainlp.parse.spacy_thai_engine import Parse  # type: ignore[assignment] # noqa: I001
 
         _tagger = Parse()
     elif engine == "ud_goeswith":
-        from pythainlp.parse.ud_goeswith import Parse
+        from pythainlp.parse.ud_goeswith import Parse  # type: ignore[assignment] # noqa: I001
 
-        _tagger = Parse(model=model)
+        _tagger = Parse(model=model if model else "KoichiYasuoka/deberta-base-thai-ud-goeswith")
     else:
         raise NotImplementedError("The engine doesn't support.")
 
     _tagger_name = engine
 
-    return _tagger(text, tag=tag)
+    return _tagger(text, tag=tag)  # type: ignore[misc]
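The parser cache is now typed Optional[Any], and each engine branch substitutes its default model when model is empty instead of passing None through. A compact sketch of that fallback pattern with a stub Parse class; the real engines load esupar or transformers models and are not reproduced here.

from typing import Any, Optional

_tagger: Optional[Any] = None      # may hold any engine's Parse instance, or nothing yet

class Parse:                       # stub standing in for an engine wrapper
    def __init__(self, model: str = "th") -> None:
        self.model = model
    def __call__(self, text: str, tag: str = "str") -> str:
        return f"[{self.model}] {text}"

def dependency_parsing(text: str, model: Optional[str] = None) -> str:
    global _tagger
    _tagger = Parse(model=model if model else "th")   # explicit default, never None
    return _tagger(text)

print(dependency_parsing("ผมกินข้าว"))
print(dependency_parsing("ผมกินข้าว", model="my-model"))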
