Commit e8bea27

Add demo function

1 parent aedb5f9 commit e8bea27

1 file changed (+32, -49 lines)


nltk/tag/perceptron.py

Lines changed: 32 additions & 49 deletions
@@ -24,20 +24,8 @@
 except ImportError:
     pass
 
-# Save trained models in tmp directory by default:
-TRAINED_TAGGER_PATH = gettempdir()
 
-TAGGER_NAME = "averaged_perceptron_tagger"
-
-
-def lang_jsons(lang="eng"):
-    return {
-        attr: f"{TAGGER_NAME}_{lang}.{attr}.json"
-        for attr in ["weights", "tagdict", "classes"]
-    }
-
-
-TAGGER_JSONS = {lang: lang_jsons(lang) for lang in ["eng", "rus", "xxx"]}
+# TAGGER_JSONS = {lang: lang_json(lang) for lang in ["eng", "rus", "xxx"]}
 
 
 @jsontags.register_tag
@@ -185,10 +173,21 @@ def __init__(self, load=True, lang="eng", loc=None):
         self.tagdict = {}
         self.classes = set()
         self.lang = lang
-        self.save_dir = path_join(TRAINED_TAGGER_PATH, f"{TAGGER_NAME}_{self.lang}")
+        # Save trained models in tmp directory by default:
+        self.TRAINED_TAGGER_PATH = gettempdir()
+        self.TAGGER_NAME = "averaged_perceptron_tagger"
+        self.save_dir = path_join(
+            self.TRAINED_TAGGER_PATH, f"{self.TAGGER_NAME}_{self.lang}"
+        )
         if load:
             self.load_from_json(lang, loc)
 
+    def lang_jsons(self, lang="eng"):
+        return {
+            attr: f"{self.TAGGER_NAME}_{lang}.{attr}.json"
+            for attr in ["weights", "tagdict", "classes"]
+        }
+
     def tag(self, tokens, return_conf=False, use_tagdict=True):
         """
         Tag tokenized sentences.
@@ -272,7 +271,7 @@ def save_to_json(self, lang="xxx", loc=None):
         if not isdir(loc):
             mkdir(loc)
 
-        jsons = lang_jsons(lang)
+        jsons = self.lang_jsons(lang)
 
         with open(path_join(loc, jsons["weights"]), "w") as fout:
            json.dump(self.model.weights, fout)
@@ -285,7 +284,7 @@ def load_from_json(self, lang="eng", loc=None):
         # Automatically find path to the tagger if location is not specified.
         if not loc:
            loc = find(f"taggers/averaged_perceptron_tagger_{lang}")
-        jsons = lang_jsons(lang)
+        jsons = self.lang_jsons(lang)
        with open(path_join(loc, jsons["weights"])) as fin:
            self.model.weights = json.load(fin)
        with open(path_join(loc, jsons["tagdict"])) as fin:
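For orientation, the per-language JSON filenames that save_to_json and load_from_json rely on now come from the instance method lang_jsons rather than a module-level helper. A minimal sketch of the resulting names and default save location, using only the constructor, method, and attributes shown in this diff:

from nltk.tag.perceptron import PerceptronTagger

# Build a tagger without loading any pretrained model.
tagger = PerceptronTagger(load=False, lang="eng")

# lang_jsons() maps each serialized attribute to its per-language filename,
# e.g. "averaged_perceptron_tagger_eng.weights.json".
for attr, filename in tagger.lang_jsons("eng").items():
    print(attr, "->", filename)

# Trained models default to a per-language directory under the system tmp
# dir, e.g. /tmp/averaged_perceptron_tagger_eng on Linux.
print(tagger.save_dir)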
@@ -376,40 +375,24 @@ def _pc(n, d):
     return (n / d) * 100
 
 
-def _load_data_conll_format(filename):
-    print("Read from file: ", filename)
-    with open(filename, "rb") as fin:
-        sentences = []
-        sentence = []
-        for line in fin.readlines():
-            line = line.strip()
-            # print line
-            if len(line) == 0:
-                sentences.append(sentence)
-                sentence = []
-                continue
-            tokens = line.split("\t")
-            word = tokens[1]
-            tag = tokens[4]
-            sentence.append((word, tag))
-        return sentences
+def _train_and_test(lang="sv"):
+    """
+    Train and test on 'lang' part of universal_treebanks corpus, which includes
+    train and test sets in conll format for 'de', 'es', 'fi', 'fr' and 'sv'.
+    Finds 0.94 accuracy on 'sv' (Swedish) test set.
+    """
+    from nltk.corpus import universal_treebanks as utb
 
-
-# Let's not give the impression that this is directly usable:
-#
-# def _get_pretrain_model():
-#     # Train and test on English part of ConLL data (WSJ part of Penn Treebank)
-#     # Train: section 2-11
-#     # Test : section 23
-#     tagger = PerceptronTagger()
-#     training = _load_data_conll_format("english_ptb_train.conll")
-#     testing = _load_data_conll_format("english_ptb_test.conll")
-#     print("Size of training and testing (sentence)", len(training), len(testing))
-#     # Train and save the model
-#     tagger.train(training, save_loc=tagger.save_dir)
-#     print("Accuracy : ", tagger.accuracy(testing))
+    tagger = PerceptronTagger(load=False, lang=lang)
+    training = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-train.conll")
+    testing = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-test.conll")
+    print(
+        f"(Lang = {lang}) training on {len(training)} and testing on {len(testing)} sentences"
+    )
+    # Train and save the model
+    tagger.train(training, save_loc=tagger.save_dir)
+    print("Accuracy : ", tagger.accuracy(testing))
 
 
 if __name__ == "__main__":
-    # _get_pretrain_model()
-    pass
+    _train_and_test()
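As a usage note, the new demo can be reproduced from the public API shown in this diff. A minimal sketch, assuming the universal_treebanks corpus files are installed under the conll paths used above; the Swedish sample sentence at the end is illustrative only:

from nltk.corpus import universal_treebanks as utb
from nltk.tag.perceptron import PerceptronTagger

lang = "sv"  # the corpus also ships 'de', 'es', 'fi' and 'fr' splits
tagger = PerceptronTagger(load=False, lang=lang)

# Tagged sentences in conll format, split into train and test sets.
training = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-train.conll")
testing = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-test.conll")

# Train (saving under the per-language tmp directory) and evaluate.
tagger.train(training, save_loc=tagger.save_dir)
print("Accuracy:", tagger.accuracy(testing))

# The trained tagger can then tag new tokenized text in memory.
print(tagger.tag("Detta är en mening .".split()))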
