Skip to content

Commit e13888d

Browse files
committed
Use encode/decode_json_obj functions
1 parent e8bea27 commit e13888d

File tree

1 file changed

+22
-27
lines changed

1 file changed

+22
-27
lines changed

nltk/tag/perceptron.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@
2525
pass
2626

2727

28-
# TAGGER_JSONS = {lang: lang_json(lang) for lang in ["eng", "rus", "xxx"]}
29-
30-
3128
@jsontags.register_tag
3229
class AveragedPerceptron:
3330
"""An averaged perceptron, as implemented by Matthew Honnibal.
@@ -182,11 +179,11 @@ def __init__(self, load=True, lang="eng", loc=None):
182179
if load:
183180
self.load_from_json(lang, loc)
184181

185-
def lang_jsons(self, lang="eng"):
186-
return {
187-
attr: f"{self.TAGGER_NAME}_{lang}.{attr}.json"
182+
def param_files(self, lang="eng"):
183+
return (
184+
f"{self.TAGGER_NAME}_{lang}.{attr}.json"
188185
for attr in ["weights", "tagdict", "classes"]
189-
}
186+
)
190187

191188
def tag(self, tokens, return_conf=False, use_tagdict=True):
192189
"""
@@ -267,41 +264,39 @@ def save_to_json(self, lang="xxx", loc=None):
267264

268265
if not loc:
269266
loc = self.save_dir
270-
271267
if not isdir(loc):
272268
mkdir(loc)
273269

274-
jsons = self.lang_jsons(lang)
275-
276-
with open(path_join(loc, jsons["weights"]), "w") as fout:
277-
json.dump(self.model.weights, fout)
278-
with open(path_join(loc, jsons["tagdict"]), "w") as fout:
279-
json.dump(self.tagdict, fout)
280-
with open(path_join(loc, jsons["classes"]), "w") as fout:
281-
json.dump(list(self.classes), fout)
270+
for param, json_file in zip(self.encode_json_obj(), self.param_files(lang)):
271+
with open(path_join(loc, json_file), "w") as fout:
272+
json.dump(param, fout)
282273

283274
def load_from_json(self, lang="eng", loc=None):
284275
# Automatically find path to the tagger if location is not specified.
285276
if not loc:
286277
loc = find(f"taggers/averaged_perceptron_tagger_{lang}")
287-
jsons = self.lang_jsons(lang)
288-
with open(path_join(loc, jsons["weights"])) as fin:
289-
self.model.weights = json.load(fin)
290-
with open(path_join(loc, jsons["tagdict"])) as fin:
291-
self.tagdict = json.load(fin)
292-
with open(path_join(loc, jsons["classes"])) as fin:
293-
self.classes = set(json.load(fin))
294-
self.model.classes = self.classes
278+
279+
def load_param(json_file):
280+
with open(path_join(loc, json_file)) as fin:
281+
return json.load(fin)
282+
283+
self.decode_json_params(
284+
load_param(js_file) for js_file in self.param_files(lang)
285+
)
286+
287+
def decode_json_params(self, params):
288+
weights, tagdict, class_list = params
289+
self.model.weights = weights
290+
self.tagdict = tagdict
291+
self.classes = self.model.classes = set(class_list)
295292

296293
def encode_json_obj(self):
297294
return self.model.weights, self.tagdict, list(self.classes)
298295

299296
@classmethod
300297
def decode_json_obj(cls, obj):
301298
tagger = cls(load=False)
302-
tagger.model.weights, tagger.tagdict, tagger.classes = obj
303-
tagger.classes = set(tagger.classes)
304-
tagger.model.classes = tagger.classes
299+
tagger.decode_json_params(obj)
305300
return tagger
306301

307302
def normalize(self, word):

0 commit comments

Comments
 (0)