
Commit c161358

refactor: refactor NLTKHelper
1 parent 028b043 commit c161358
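
Summary (from the diff below): the module-level resource_path and the lazily populated _stopwords dict are replaced by a self-contained class. The NLTK data path is now resolved in __init__ (and overridable via nltk_data_path), the required corpora are fetched once through _ensure_nltk_data, language codes are validated against SUPPORTED_LANGUAGES, and stopword lookups are memoized with lru_cache.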

File tree

1 file changed: +40 −25 lines


graphgen/utils/help_nltk.py

Lines changed: 40 additions & 25 deletions
@@ -1,39 +1,54 @@
+from functools import lru_cache
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Final, Optional
 import nltk
 import jieba

-resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources")
-
-
 class NLTKHelper:
-    _stopwords: Dict[str, Optional[List[str]]] = {
-        "english": None,
-        "chinese": None,
+    """
+    NLTK helper class
+    """
+
+    SUPPORTED_LANGUAGES: Final[Dict[str, str]] = {
+        "en": "english",
+        "zh": "chinese"
+    }
+    _NLTK_PACKAGES: Final[Dict[str, str]] = {
+        "stopwords": "corpora",
+        "punkt_tab": "tokenizers"
     }

-    def __init__(self):
+    def __init__(self, nltk_data_path: Optional[str] = None):
+        self._nltk_path = nltk_data_path or os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "resources",
+            "nltk_data"
+        )
+        nltk.data.path.append(self._nltk_path)
         jieba.initialize()

+        self._ensure_nltk_data("stopwords")
+        self._ensure_nltk_data("punkt_tab")
+
+    def _ensure_nltk_data(self, package_name: str) -> None:
+        """
+        ensure nltk data is downloaded
+        """
+        try:
+            nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}")
+        except LookupError:
+            nltk.download(package_name, download_dir=self._nltk_path, quiet=True)
+
+    @lru_cache(maxsize=2)
     def get_stopwords(self, lang: str) -> List[str]:
-        nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
-        if self._stopwords[lang] is None:
-            try:
-                nltk.data.find("corpora/stopwords")
-            except LookupError:
-                nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data"))
-
-            self._stopwords[lang] = nltk.corpus.stopwords.words(lang)
-        return self._stopwords[lang]
-
-    @staticmethod
-    def word_tokenize(text: str, lang: str) -> List[str]:
+        if lang not in self.SUPPORTED_LANGUAGES:
+            raise ValueError(f"Language {lang} is not supported.")
+        return nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang])
+
+    def word_tokenize(self, text: str, lang: str) -> List[str]:
+        if lang not in self.SUPPORTED_LANGUAGES:
+            raise ValueError(f"Language {lang} is not supported.")
         if lang == "zh":
             return jieba.lcut(text)
-        nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
-        try:
-            nltk.data.find("tokenizers/punkt_tab")
-        except LookupError:
-            nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))

         return nltk.word_tokenize(text)
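
For context, a minimal usage sketch of the refactored class. The import path is assumed from the file path above; the sample strings are illustrative only.

    from graphgen.utils.help_nltk import NLTKHelper

    # Default data dir is <package>/resources/nltk_data; pass nltk_data_path to override.
    helper = NLTKHelper()

    # English: stopwords from the NLTK corpus, tokenization via punkt.
    print(helper.get_stopwords("en")[:5])
    print(helper.word_tokenize("Knowledge graphs are useful.", "en"))

    # Chinese: tokenization is delegated to jieba.
    print(helper.word_tokenize("知识图谱很有用", "zh"))

    # Any other language code is rejected.
    helper.get_stopwords("fr")  # raises ValueError: Language fr is not supported.

Note that @lru_cache on get_stopwords keys the cache on (self, lang), so each helper instance caches its own stopword lists, and word_tokenize is no longer a @staticmethod because it now consults SUPPORTED_LANGUAGES on the instance.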
