
Commit 524177f

Fix saving PerceptronTagger
1 parent 1642942 commit 524177f

File tree

1 file changed: +58 −76 lines changed

nltk/tag/perceptron.py

Lines changed: 58 additions & 76 deletions
@@ -12,8 +12,9 @@
 import logging
 import random
 from collections import defaultdict
+from os.path import join as path_join
+from tempfile import gettempdir

-from nltk import jsontags
 from nltk.data import find, load
 from nltk.tag.api import TaggerI

@@ -22,37 +23,29 @@
 except ImportError:
     pass

-TRAINED_TAGGER_PATH = "averaged_perceptron_tagger/"
-
-TAGGER_JSONS = {
-    "eng": {
-        "weights": "averaged_perceptron_tagger_eng.weights.json",
-        "tagdict": "averaged_perceptron_tagger_eng.tagdict.json",
-        "classes": "averaged_perceptron_tagger_eng.classes.json",
-    },
-    "rus": {
-        "weights": "averaged_perceptron_tagger_rus.weights.json",
-        "tagdict": "averaged_perceptron_tagger_rus.tagdict.json",
-        "classes": "averaged_perceptron_tagger_rus.classes.json",
-    },
-    "xxx": {
-        "weights": "averaged_perceptron_tagger.xxx.weights.json",
-        "tagdict": "averaged_perceptron_tagger.xxx.tagdict.json",
-        "classes": "averaged_perceptron_tagger.xxx.classes.json",
-    },
-}
-
-
-@jsontags.register_tag
+# Save trained models in tmp directory by default:
+TRAINED_TAGGER_PATH = gettempdir()
+
+TAGGER_NAME = "averaged_perceptron_tagger"
+
+
+def lang_jsons(lang="eng"):
+    return {
+        attr: f"{TAGGER_NAME}_{lang}.{attr}.json"
+        for attr in ["weights", "tagdict", "classes"]
+    }
+
+
+TAGGER_JSONS = {lang: lang_jsons(lang) for lang in ["eng", "rus", "xxx"]}
+
+
 class AveragedPerceptron:
     """An averaged perceptron, as implemented by Matthew Honnibal.

     See more implementation details here:
         https://explosion.ai/blog/part-of-speech-pos-tagger-in-python
     """

-    json_tag = "nltk.tag.perceptron.AveragedPerceptron"
-
     def __init__(self, weights=None):
         # Each feature gets its own weight vector, so weights is a dict-of-dicts
         self.weights = weights if weights else {}
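As a quick illustration (not part of the diff), this is the new lang_jsons() helper from the hunk above evaluated by hand, showing the filenames it generates for one language:

    # Illustrative only: reproduce lang_jsons() from the hunk above.
    TAGGER_NAME = "averaged_perceptron_tagger"

    def lang_jsons(lang="eng"):
        # One JSON filename per serialized attribute of the tagger.
        return {
            attr: f"{TAGGER_NAME}_{lang}.{attr}.json"
            for attr in ["weights", "tagdict", "classes"]
        }

    print(lang_jsons("rus"))
    # {'weights': 'averaged_perceptron_tagger_rus.weights.json',
    #  'tagdict': 'averaged_perceptron_tagger_rus.tagdict.json',
    #  'classes': 'averaged_perceptron_tagger_rus.classes.json'}

Note that for the "xxx" placeholder the generated names use an underscore (averaged_perceptron_tagger_xxx.weights.json) where the removed table used dots (averaged_perceptron_tagger.xxx.weights.json).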
@@ -129,15 +122,7 @@ def load(self, path):
         with open(path) as fin:
             self.weights = json.load(fin)

-    def encode_json_obj(self):
-        return self.weights
-
-    @classmethod
-    def decode_json_obj(cls, obj):
-        return cls(obj)
-

-@jsontags.register_tag
 class PerceptronTagger(TaggerI):
     """
     Greedy Averaged Perceptron tagger, as implemented by Matthew Honnibal.
@@ -151,7 +136,7 @@ class PerceptronTagger(TaggerI):
     >>> tagger = PerceptronTagger(load=False)

     >>> tagger.train([[('today','NN'),('is','VBZ'),('good','JJ'),('day','NN')],
-    ... [('yes','NNS'),('it','PRP'),('beautiful','JJ')]])
+    ... [('yes','NNS'),('it','PRP'),('beautiful','JJ')]], save_loc=tagger.save_dir)

     >>> tagger.tag(['today','is','a','beautiful','day'])
     [('today', 'NN'), ('is', 'PRP'), ('a', 'PRP'), ('beautiful', 'JJ'), ('day', 'NN')]
@@ -167,8 +152,6 @@ class PerceptronTagger(TaggerI):
     [('The', 'DT'), ('red', 'JJ'), ('cat', 'NN')]
     """

-    json_tag = "nltk.tag.sequential.PerceptronTagger"
-
     START = ["-START-", "-START2-"]
     END = ["-END-", "-END2-"]

@@ -179,6 +162,8 @@ def __init__(self, load=True, lang="eng"):
         self.model = AveragedPerceptron()
         self.tagdict = {}
         self.classes = set()
+        self.lang = lang
+        self.save_dir = path_join(TRAINED_TAGGER_PATH, f"{TAGGER_NAME}_{self.lang}")
         if load:
             self.load_from_json(lang)

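For orientation, a small sketch (not from the commit) of what the new save_dir resolves to, using the module-level names introduced earlier in this diff; the /tmp prefix is only an example, since gettempdir() is platform-dependent:

    from os.path import join as path_join
    from tempfile import gettempdir

    TRAINED_TAGGER_PATH = gettempdir()
    TAGGER_NAME = "averaged_perceptron_tagger"

    # Same expression __init__ builds for the default lang="eng":
    save_dir = path_join(TRAINED_TAGGER_PATH, f"{TAGGER_NAME}_eng")
    print(save_dir)  # e.g. /tmp/averaged_perceptron_tagger_eng on Linux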
@@ -253,43 +238,38 @@ def train(self, sentences, save_loc=None, nr_iter=5):
         self.model.average_weights()
         # Save to json files.
         if save_loc is not None:
-            self.save_to_json(loc)
+            self.save_to_json(lang=self.lang, loc=save_loc)
+
+    def save_to_json(self, lang="xxx", loc=None):
+        from os import mkdir
+        from os.path import isdir
+
+        if not loc:
+            loc = self.save_dir

-    def save_to_json(self, loc, lang="xxx"):
-        # TODO:
-        assert os.isdir(
-            TRAINED_TAGGER_PATH
-        ), f"Path set for saving needs to be a directory"
+        if not isdir(loc):
+            mkdir(loc)

-        with open(loc + TAGGER_JSONS[lang]["weights"], "w") as fout:
+        jsons = lang_jsons(lang)
+
+        with open(path_join(loc, jsons["weights"]), "w") as fout:
             json.dump(self.model.weights, fout)
-        with open(loc + TAGGER_JSONS[lang]["tagdict"], "w") as fout:
+        with open(path_join(loc, jsons["tagdict"]), "w") as fout:
             json.dump(self.tagdict, fout)
-        with open(loc + TAGGER_JSONS[lang]["classes"], "w") as fout:
-            json.dump(self.classes, fout)
+        with open(path_join(loc, jsons["classes"]), "w") as fout:
+            json.dump(list(self.model.classes), fout)

-    def load_from_json(self, lang="eng"):
+    def load_from_json(self, lang="eng", loc=None):
         # Automatically find path to the tagger if location is not specified.
-        loc = find(f"taggers/averaged_perceptron_tagger_{lang}/")
-        with open(loc + TAGGER_JSONS[lang]["weights"]) as fin:
+        if not loc:
+            loc = find(f"taggers/averaged_perceptron_tagger_{lang}/")
+        jsons = lang_jsons(lang)
+        with open(loc + jsons["weights"]) as fin:
             self.model.weights = json.load(fin)
-        with open(loc + TAGGER_JSONS[lang]["tagdict"]) as fin:
+        with open(loc + jsons["tagdict"]) as fin:
             self.tagdict = json.load(fin)
-        with open(loc + TAGGER_JSONS[lang]["classes"]) as fin:
-            self.classes = set(json.load(fin))
-
-        self.model.classes = self.classes
-
-    def encode_json_obj(self):
-        return self.model.weights, self.tagdict, list(self.classes)
-
-    @classmethod
-    def decode_json_obj(cls, obj):
-        tagger = cls(load=False)
-        tagger.model.weights, tagger.tagdict, tagger.classes = obj
-        tagger.classes = set(tagger.classes)
-        tagger.model.classes = tagger.classes
-        return tagger
+        with open(loc + jsons["classes"]) as fin:
+            self.model.classes = set(json.load(fin))

     def normalize(self, word):
         """
@@ -381,17 +361,19 @@ def _load_data_conll_format(filename):
     return sentences


-def _get_pretrain_model():
-    # Train and test on English part of ConLL data (WSJ part of Penn Treebank)
-    # Train: section 2-11
-    # Test : section 23
-    tagger = PerceptronTagger()
-    training = _load_data_conll_format("english_ptb_train.conll")
-    testing = _load_data_conll_format("english_ptb_test.conll")
-    print("Size of training and testing (sentence)", len(training), len(testing))
-    # Train and save the model
-    tagger.train(training, TRAINED_TAGGER_PATH)
-    print("Accuracy : ", tagger.accuracy(testing))
+# Let's not give the impression that this is directly usable:
+#
+# def _get_pretrain_model():
+#     # Train and test on English part of ConLL data (WSJ part of Penn Treebank)
+#     # Train: section 2-11
+#     # Test : section 23
+#     tagger = PerceptronTagger()
+#     training = _load_data_conll_format("english_ptb_train.conll")
+#     testing = _load_data_conll_format("english_ptb_test.conll")
+#     print("Size of training and testing (sentence)", len(training), len(testing))
+#     # Train and save the model
+#     tagger.train(training, save_loc=tagger.save_dir)
+#     print("Accuracy : ", tagger.accuracy(testing))


 if __name__ == "__main__":
