@@ -1,39 +1,54 @@
+from functools import lru_cache
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Final, Optional
 import nltk
 import jieba
 
-resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources")
-
-
 class NLTKHelper:
-    _stopwords: Dict[str, Optional[List[str]]] = {
-        "english": None,
-        "chinese": None,
+    """
+    Helper for word tokenization and stopword lookup: NLTK for English,
+    jieba for Chinese.
+    """
+
+    SUPPORTED_LANGUAGES: Final[Dict[str, str]] = {
+        "en": "english",
+        "zh": "chinese"
+    }
+    _NLTK_PACKAGES: Final[Dict[str, str]] = {
+        "stopwords": "corpora",
+        "punkt_tab": "tokenizers"
     }
 
-    def __init__(self):
+    def __init__(self, nltk_data_path: Optional[str] = None):
+        self._nltk_path = nltk_data_path or os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "resources",
+            "nltk_data"
+        )
+        nltk.data.path.append(self._nltk_path)
         jieba.initialize()
 
+        self._ensure_nltk_data("stopwords")
+        self._ensure_nltk_data("punkt_tab")
+
+    def _ensure_nltk_data(self, package_name: str) -> None:
+        """
+        Ensure an NLTK data package is available, downloading it on first use.
+        """
+        try:
+            nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}")
+        except LookupError:
+            nltk.download(package_name, download_dir=self._nltk_path, quiet=True)
+
+    @lru_cache(maxsize=2)
     def get_stopwords(self, lang: str) -> List[str]:
-        nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
-        if self._stopwords[lang] is None:
-            try:
-                nltk.data.find("corpora/stopwords")
-            except LookupError:
-                nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data"))
-
-            self._stopwords[lang] = nltk.corpus.stopwords.words(lang)
-        return self._stopwords[lang]
-
-    @staticmethod
-    def word_tokenize(text: str, lang: str) -> List[str]:
+        if lang not in self.SUPPORTED_LANGUAGES:
+            raise ValueError(f"Language {lang} is not supported.")
+        return nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang])
+
+    def word_tokenize(self, text: str, lang: str) -> List[str]:
+        if lang not in self.SUPPORTED_LANGUAGES:
+            raise ValueError(f"Language {lang} is not supported.")
         if lang == "zh":
             return jieba.lcut(text)
-        nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
-        try:
-            nltk.data.find("tokenizers/punkt_tab")
-        except LookupError:
-            nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data"))
 
         return nltk.word_tokenize(text)
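
A quick usage sketch of the new API (hypothetical caller; the import path nltk_helper is an assumption, not part of this commit):

from nltk_helper import NLTKHelper  # hypothetical module name

# Appends the data path, initializes jieba, and fetches
# stopwords/punkt_tab on first run if they are missing.
helper = NLTKHelper()

stops = helper.get_stopwords("en")                       # English stopword list from NLTK
tokens = helper.word_tokenize("Hello, world!", "en")     # -> ['Hello', ',', 'world', '!']
zh_tokens = helper.word_tokenize("我爱自然语言处理", "zh")   # segmented via jieba.lcut
# helper.get_stopwords("fr")                             # would raise ValueError: unsupported language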
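
One caveat worth flagging: @lru_cache on an instance method keys the cache on (self, lang) and holds a strong reference to self, so NLTKHelper instances are never garbage-collected while the cache lives, and with maxsize=2 a second instance can evict the first instance's entries. A minimal alternative sketch, assuming per-language caching is the actual intent (this is not part of the commit):

from functools import lru_cache
from typing import List

import nltk

@lru_cache(maxsize=None)
def _cached_stopwords(nltk_lang: str) -> List[str]:
    # Module-level cache keyed by language name only; holds no helper reference.
    return nltk.corpus.stopwords.words(nltk_lang)

# get_stopwords would then delegate after its validation check:
#     return _cached_stopwords(self.SUPPORTED_LANGUAGES[lang])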