Skip to content

Commit ebaf5f9

Browse files
authored
Merge pull request nltk#3383 from ekaf/hotfix-3379
Fix saving PerceptronTagger
2 parents 4910b8e + e13888d commit ebaf5f9

File tree

1 file changed

+79
-85
lines changed

1 file changed

+79
-85
lines changed

nltk/tag/perceptron.py

Lines changed: 79 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import logging
1313
import random
1414
from collections import defaultdict
15+
from os.path import join as path_join
16+
from tempfile import gettempdir
1517

1618
from nltk import jsontags
1719
from nltk.data import find, load
@@ -22,26 +24,6 @@
2224
except ImportError:
2325
pass
2426

25-
TRAINED_TAGGER_PATH = "averaged_perceptron_tagger/"
26-
27-
TAGGER_JSONS = {
28-
"eng": {
29-
"weights": "averaged_perceptron_tagger_eng.weights.json",
30-
"tagdict": "averaged_perceptron_tagger_eng.tagdict.json",
31-
"classes": "averaged_perceptron_tagger_eng.classes.json",
32-
},
33-
"rus": {
34-
"weights": "averaged_perceptron_tagger_rus.weights.json",
35-
"tagdict": "averaged_perceptron_tagger_rus.tagdict.json",
36-
"classes": "averaged_perceptron_tagger_rus.classes.json",
37-
},
38-
"xxx": {
39-
"weights": "averaged_perceptron_tagger.xxx.weights.json",
40-
"tagdict": "averaged_perceptron_tagger.xxx.tagdict.json",
41-
"classes": "averaged_perceptron_tagger.xxx.classes.json",
42-
},
43-
}
44-
4527

4628
@jsontags.register_tag
4729
class AveragedPerceptron:
@@ -145,15 +127,23 @@ class PerceptronTagger(TaggerI):
145127
https://explosion.ai/blog/part-of-speech-pos-tagger-in-python
146128
147129
>>> from nltk.tag.perceptron import PerceptronTagger
148-
149-
Train the model
150-
151130
>>> tagger = PerceptronTagger(load=False)
152131
132+
Train and save the model:
133+
153134
>>> tagger.train([[('today','NN'),('is','VBZ'),('good','JJ'),('day','NN')],
154-
... [('yes','NNS'),('it','PRP'),('beautiful','JJ')]])
135+
... [('yes','NNS'),('it','PRP'),('beautiful','JJ')]], save_loc=tagger.save_dir)
136+
137+
Load the saved model:
155138
156-
>>> tagger.tag(['today','is','a','beautiful','day'])
139+
>>> tagger2 = PerceptronTagger(loc=tagger.save_dir)
140+
>>> print(sorted(list(tagger2.classes)))
141+
['JJ', 'NN', 'NNS', 'PRP', 'VBZ']
142+
143+
>>> print(tagger2.classes == tagger.classes)
144+
True
145+
146+
>>> tagger2.tag(['today','is','a','beautiful','day'])
157147
[('today', 'NN'), ('is', 'PRP'), ('a', 'PRP'), ('beautiful', 'JJ'), ('day', 'NN')]
158148
159149
Use the pretrained model (the default constructor)
@@ -167,20 +157,33 @@ class PerceptronTagger(TaggerI):
167157
[('The', 'DT'), ('red', 'JJ'), ('cat', 'NN')]
168158
"""
169159

170-
json_tag = "nltk.tag.sequential.PerceptronTagger"
160+
json_tag = "nltk.tag.perceptron.PerceptronTagger"
171161

172162
START = ["-START-", "-START2-"]
173163
END = ["-END-", "-END2-"]
174164

175-
def __init__(self, load=True, lang="eng"):
165+
def __init__(self, load=True, lang="eng", loc=None):
176166
"""
177167
:param load: Load the json model upon instantiation.
178168
"""
179169
self.model = AveragedPerceptron()
180170
self.tagdict = {}
181171
self.classes = set()
172+
self.lang = lang
173+
# Save trained models in tmp directory by default:
174+
self.TRAINED_TAGGER_PATH = gettempdir()
175+
self.TAGGER_NAME = "averaged_perceptron_tagger"
176+
self.save_dir = path_join(
177+
self.TRAINED_TAGGER_PATH, f"{self.TAGGER_NAME}_{self.lang}"
178+
)
182179
if load:
183-
self.load_from_json(lang)
180+
self.load_from_json(lang, loc)
181+
182+
def param_files(self, lang="eng"):
183+
return (
184+
f"{self.TAGGER_NAME}_{lang}.{attr}.json"
185+
for attr in ["weights", "tagdict", "classes"]
186+
)
184187

185188
def tag(self, tokens, return_conf=False, use_tagdict=True):
186189
"""
@@ -253,42 +256,47 @@ def train(self, sentences, save_loc=None, nr_iter=5):
253256
self.model.average_weights()
254257
# Save to json files.
255258
if save_loc is not None:
256-
self.save_to_json(loc)
257-
258-
def save_to_json(self, loc, lang="xxx"):
259-
# TODO:
260-
assert os.isdir(
261-
TRAINED_TAGGER_PATH
262-
), f"Path set for saving needs to be a directory"
263-
264-
with open(loc + TAGGER_JSONS[lang]["weights"], "w") as fout:
265-
json.dump(self.model.weights, fout)
266-
with open(loc + TAGGER_JSONS[lang]["tagdict"], "w") as fout:
267-
json.dump(self.tagdict, fout)
268-
with open(loc + TAGGER_JSONS[lang]["classes"], "w") as fout:
269-
json.dump(self.classes, fout)
270-
271-
def load_from_json(self, lang="eng"):
259+
self.save_to_json(lang=self.lang, loc=save_loc)
260+
261+
def save_to_json(self, lang="xxx", loc=None):
262+
from os import mkdir
263+
from os.path import isdir
264+
265+
if not loc:
266+
loc = self.save_dir
267+
if not isdir(loc):
268+
mkdir(loc)
269+
270+
for param, json_file in zip(self.encode_json_obj(), self.param_files(lang)):
271+
with open(path_join(loc, json_file), "w") as fout:
272+
json.dump(param, fout)
273+
274+
def load_from_json(self, lang="eng", loc=None):
272275
# Automatically find path to the tagger if location is not specified.
273-
loc = find(f"taggers/averaged_perceptron_tagger_{lang}/")
274-
with open(loc + TAGGER_JSONS[lang]["weights"]) as fin:
275-
self.model.weights = json.load(fin)
276-
with open(loc + TAGGER_JSONS[lang]["tagdict"]) as fin:
277-
self.tagdict = json.load(fin)
278-
with open(loc + TAGGER_JSONS[lang]["classes"]) as fin:
279-
self.classes = set(json.load(fin))
276+
if not loc:
277+
loc = find(f"taggers/averaged_perceptron_tagger_{lang}")
280278

281-
self.model.classes = self.classes
279+
def load_param(json_file):
280+
with open(path_join(loc, json_file)) as fin:
281+
return json.load(fin)
282+
283+
self.decode_json_params(
284+
load_param(js_file) for js_file in self.param_files(lang)
285+
)
286+
287+
def decode_json_params(self, params):
288+
weights, tagdict, class_list = params
289+
self.model.weights = weights
290+
self.tagdict = tagdict
291+
self.classes = self.model.classes = set(class_list)
282292

283293
def encode_json_obj(self):
284294
return self.model.weights, self.tagdict, list(self.classes)
285295

286296
@classmethod
287297
def decode_json_obj(cls, obj):
288298
tagger = cls(load=False)
289-
tagger.model.weights, tagger.tagdict, tagger.classes = obj
290-
tagger.classes = set(tagger.classes)
291-
tagger.model.classes = tagger.classes
299+
tagger.decode_json_params(obj)
292300
return tagger
293301

294302
def normalize(self, word):
@@ -362,38 +370,24 @@ def _pc(n, d):
362370
return (n / d) * 100
363371

364372

365-
def _load_data_conll_format(filename):
366-
print("Read from file: ", filename)
367-
with open(filename, "rb") as fin:
368-
sentences = []
369-
sentence = []
370-
for line in fin.readlines():
371-
line = line.strip()
372-
# print line
373-
if len(line) == 0:
374-
sentences.append(sentence)
375-
sentence = []
376-
continue
377-
tokens = line.split("\t")
378-
word = tokens[1]
379-
tag = tokens[4]
380-
sentence.append((word, tag))
381-
return sentences
382-
383-
384-
def _get_pretrain_model():
385-
# Train and test on English part of ConLL data (WSJ part of Penn Treebank)
386-
# Train: section 2-11
387-
# Test : section 23
388-
tagger = PerceptronTagger()
389-
training = _load_data_conll_format("english_ptb_train.conll")
390-
testing = _load_data_conll_format("english_ptb_test.conll")
391-
print("Size of training and testing (sentence)", len(training), len(testing))
373+
def _train_and_test(lang="sv"):
374+
"""
375+
Train and test on 'lang' part of universal_treebanks corpus, which includes
376+
train and test sets in conll format for 'de', 'es', 'fi', 'fr' and 'sv'.
377+
Finds 0.94 accuracy on 'sv' (Swedish) test set.
378+
"""
379+
from nltk.corpus import universal_treebanks as utb
380+
381+
tagger = PerceptronTagger(load=False, lang=lang)
382+
training = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-train.conll")
383+
testing = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-test.conll")
384+
print(
385+
f"(Lang = {lang}) training on {len(training)} and testing on {len(testing)} sentences"
386+
)
392387
# Train and save the model
393-
tagger.train(training, TRAINED_TAGGER_PATH)
388+
tagger.train(training, save_loc=tagger.save_dir)
394389
print("Accuracy : ", tagger.accuracy(testing))
395390

396391

397392
if __name__ == "__main__":
398-
# _get_pretrain_model()
399-
pass
393+
_train_and_test()

0 commit comments

Comments
 (0)