77import string
88from collections import Counter
99
10- from .utils import load_file , write_file , _parse_into_words
10+ from .utils import load_file , write_file , _parse_into_words , ENSURE_UNICODE
1111
1212
1313class SpellChecker (object ):
@@ -62,10 +62,12 @@ def __init__(
6262
def __contains__(self, key):
    """ Support the `in` operator for checking whether a word is known

        Args:
            key (str): The word to look up
        Returns:
            bool: The result of testing the converted key against the \
            word frequency object """
    # ENSURE_UNICODE comes from .utils; presumably it coerces the key to a
    # unicode string before the lookup -- confirm against that helper
    return ENSURE_UNICODE(key) in self._word_frequency
6667
def __getitem__(self, key):
    """ Support subscript access for looking up a word's frequency

        Args:
            key (str): The word to look up
        Returns:
            The value the word frequency object holds for the \
            converted key """
    normalized = ENSURE_UNICODE(key)
    frequencies = self._word_frequency
    return frequencies[normalized]
7072
7173 @property
@@ -105,6 +107,7 @@ def split_words(self, text):
105107 text (str): The text to split into individual words
106108 Returns:
107109 list(str): A listing of all words in the provided text """
110+ text = ENSURE_UNICODE (text )
108111 return self ._tokenizer (text )
109112
110113 def export (self , filepath , encoding = "utf-8" , gzipped = True ):
@@ -131,6 +134,7 @@ def word_probability(self, word, total_words=None):
131134 float: The probability that the word is the correct word """
132135 if total_words is None :
133136 total_words = self ._word_frequency .total_words
137+ word = ENSURE_UNICODE (word )
134138 return self ._word_frequency .dictionary [word ] / total_words
135139
136140 def correction (self , word ):
@@ -140,6 +144,7 @@ def correction(self, word):
140144 word (str): The word to correct
141145 Returns:
142146 str: The most likely candidate """
147+ word = ENSURE_UNICODE (word )
143148 candidates = list (self .candidates (word ))
144149 return max (sorted (candidates ), key = self .word_probability )
145150
@@ -151,6 +156,7 @@ def candidates(self, word):
151156 word (str): The word for which to calculate candidate spellings
152157 Returns:
153158 set: The set of words that are possible candidates """
159+ word = ENSURE_UNICODE (word )
154160 if self .known ([word ]): # short-cut if word is correct already
155161 return {word }
156162 # get edit distance 1...
@@ -174,6 +180,7 @@ def known(self, words):
174180 Returns:
175181 set: The set of those words from the input that are in the \
176182 corpus """
183+ words = [ENSURE_UNICODE (w ) for w in words ]
177184 tmp = [w if self ._case_sensitive else w .lower () for w in words ]
178185 return set (
179186 w
@@ -191,6 +198,7 @@ def unknown(self, words):
191198 Returns:
192199 set: The set of those words from the input that are not in \
193200 the corpus """
201+ words = [ENSURE_UNICODE (w ) for w in words ]
194202 tmp = [
195203 w if self ._case_sensitive else w .lower ()
196204 for w in words
@@ -207,7 +215,7 @@ def edit_distance_1(self, word):
207215 Returns:
208216 set: The set of strings that are edit distance one from the \
209217 provided word """
210- word = word .lower ()
218+ word = ENSURE_UNICODE ( word ) .lower ()
211219 if self ._check_if_should_check (word ) is False :
212220 return {word }
213221 letters = self ._word_frequency .letters
@@ -227,7 +235,7 @@ def edit_distance_2(self, word):
227235 Returns:
228236 set: The set of strings that are edit distance two from the \
229237 provided word """
230- word = word .lower ()
238+ word = ENSURE_UNICODE ( word ) .lower ()
231239 return [
232240 e2 for e1 in self .edit_distance_1 (word ) for e2 in self .edit_distance_1 (e1 )
233241 ]
@@ -241,8 +249,13 @@ def __edit_distance_alt(self, words):
241249 Returns:
242250 set: The set of strings that are edit distance two from the \
243251 provided words """
244- words = [word .lower () for word in words ]
245- return [e2 for e1 in words for e2 in self .edit_distance_1 (e1 )]
252+ words = [ENSURE_UNICODE (w ) for w in words ]
253+ tmp = [
254+ w if self ._case_sensitive else w .lower ()
255+ for w in words
256+ if self ._check_if_should_check (w )
257+ ]
258+ return [e2 for e1 in tmp for e2 in self .edit_distance_1 (e1 )]
246259
247260 @staticmethod
248261 def _check_if_should_check (word ):
@@ -283,11 +296,13 @@ def __init__(self, tokenizer=None, case_sensitive=False):
283296
def __contains__(self, key):
    """ Support the `in` operator against the underlying dictionary

        Args:
            key (str): The word to look up
        Returns:
            bool: True if the key (lower-cased unless the instance is \
            case sensitive) is in the dictionary """
    lookup = ENSURE_UNICODE(key)
    if not self._case_sensitive:
        # entries are matched in lower case when case is not significant
        lookup = lookup.lower()
    return lookup in self._dictionary
288302
def __getitem__(self, key):
    """ Support subscript access against the underlying dictionary

        Args:
            key (str): The word to look up
        Returns:
            The dictionary's value for the key (lower-cased unless the \
            instance is case sensitive) """
    lookup = ENSURE_UNICODE(key)
    if self._case_sensitive:
        return self._dictionary[lookup]
    return self._dictionary[lookup.lower()]
293308
@@ -298,6 +313,7 @@ def pop(self, key, default=None):
298313 Args:
299314 key (str): The key to remove
300315 default (obj): The value to return if key is not present """
316+ key = ENSURE_UNICODE (key )
301317 key = key if self ._case_sensitive else key .lower ()
302318 return self ._dictionary .pop (key , default )
303319
@@ -344,6 +360,7 @@ def tokenize(self, text):
344360 str: The next `word` in the tokenized string
345361 Note:
346362 This is the same as the `spellchecker.split_words()` """
363+ text = ENSURE_UNICODE (text )
347364 for word in self ._tokenizer (text ):
348365 yield word if self ._case_sensitive else word .lower ()
349366
@@ -408,6 +425,7 @@ def load_text(self, text, tokenizer=None):
408425 text (str): The text to be loaded
409426 tokenizer (function): The function to use to tokenize a string
410427 """
428+ text = ENSURE_UNICODE (text )
411429 if tokenizer :
412430 words = [x if self ._case_sensitive else x .lower () for x in tokenizer (text )]
413431 else :
@@ -421,6 +439,7 @@ def load_words(self, words):
421439
422440 Args:
423441 words (list): The list of words to be loaded """
442+ words = [ENSURE_UNICODE (w ) for w in words ]
424443 self ._dictionary .update (
425444 [word if self ._case_sensitive else word .lower () for word in words ]
426445 )
@@ -431,13 +450,15 @@ def add(self, word):
431450
432451 Args:
433452 word (str): The word to add """
453+ word = ENSURE_UNICODE (word )
434454 self .load_words ([word ])
435455
def remove_words(self, words):
    """ Remove a list of words from the word frequency list

        Args:
            words (list): The list of words to remove
        Note:
            Words that are not in the dictionary are skipped rather \
            than raising a `KeyError` """
    # convert everything up front so a bad element fails before any removal
    words = [ENSURE_UNICODE(w) for w in words]
    for word in words:
        key = word if self._case_sensitive else word.lower()
        # Passing a default keeps one absent word from aborting the loop
        # half-way, which would leave earlier words removed while skipping
        # the _update_dictionary() call below.
        self._dictionary.pop(key, None)
    self._update_dictionary()
@@ -447,6 +468,7 @@ def remove(self, word):
447468
448469 Args:
449470 word (str): The word to remove """
471+ word = ENSURE_UNICODE (word )
450472 self ._dictionary .pop (word if self ._case_sensitive else word .lower ())
451473 self ._update_dictionary ()
452474
0 commit comments