|
25 | 25 | pass |
26 | 26 |
|
27 | 27 |
|
28 | | -# TAGGER_JSONS = {lang: lang_json(lang) for lang in ["eng", "rus", "xxx"]} |
29 | | - |
30 | | - |
31 | 28 | @jsontags.register_tag |
32 | 29 | class AveragedPerceptron: |
33 | 30 | """An averaged perceptron, as implemented by Matthew Honnibal. |
@@ -182,11 +179,11 @@ def __init__(self, load=True, lang="eng", loc=None): |
182 | 179 | if load: |
183 | 180 | self.load_from_json(lang, loc) |
184 | 181 |
|
185 | | - def lang_jsons(self, lang="eng"): |
186 | | - return { |
187 | | - attr: f"{self.TAGGER_NAME}_{lang}.{attr}.json" |
| 182 | + def param_files(self, lang="eng"): |
| 183 | + return ( |
| 184 | + f"{self.TAGGER_NAME}_{lang}.{attr}.json" |
188 | 185 | for attr in ["weights", "tagdict", "classes"] |
189 | | - } |
| 186 | + ) |
190 | 187 |
|
191 | 188 | def tag(self, tokens, return_conf=False, use_tagdict=True): |
192 | 189 | """ |
@@ -267,41 +264,39 @@ def save_to_json(self, lang="xxx", loc=None): |
267 | 264 |
|
268 | 265 | if not loc: |
269 | 266 | loc = self.save_dir |
270 | | - |
271 | 267 | if not isdir(loc): |
272 | 268 | mkdir(loc) |
273 | 269 |
|
274 | | - jsons = self.lang_jsons(lang) |
275 | | - |
276 | | - with open(path_join(loc, jsons["weights"]), "w") as fout: |
277 | | - json.dump(self.model.weights, fout) |
278 | | - with open(path_join(loc, jsons["tagdict"]), "w") as fout: |
279 | | - json.dump(self.tagdict, fout) |
280 | | - with open(path_join(loc, jsons["classes"]), "w") as fout: |
281 | | - json.dump(list(self.classes), fout) |
| 270 | + for param, json_file in zip(self.encode_json_obj(), self.param_files(lang)): |
| 271 | + with open(path_join(loc, json_file), "w") as fout: |
| 272 | + json.dump(param, fout) |
282 | 273 |
|
283 | 274 | def load_from_json(self, lang="eng", loc=None): |
284 | 275 | # Automatically find path to the tagger if location is not specified. |
285 | 276 | if not loc: |
286 | 277 | loc = find(f"taggers/averaged_perceptron_tagger_{lang}") |
287 | | - jsons = self.lang_jsons(lang) |
288 | | - with open(path_join(loc, jsons["weights"])) as fin: |
289 | | - self.model.weights = json.load(fin) |
290 | | - with open(path_join(loc, jsons["tagdict"])) as fin: |
291 | | - self.tagdict = json.load(fin) |
292 | | - with open(path_join(loc, jsons["classes"])) as fin: |
293 | | - self.classes = set(json.load(fin)) |
294 | | - self.model.classes = self.classes |
| 278 | + |
| 279 | + def load_param(json_file): |
| 280 | + with open(path_join(loc, json_file)) as fin: |
| 281 | + return json.load(fin) |
| 282 | + |
| 283 | + self.decode_json_params( |
| 284 | + load_param(js_file) for js_file in self.param_files(lang) |
| 285 | + ) |
| 286 | + |
| 287 | + def decode_json_params(self, params): |
| 288 | + weights, tagdict, class_list = params |
| 289 | + self.model.weights = weights |
| 290 | + self.tagdict = tagdict |
| 291 | + self.classes = self.model.classes = set(class_list) |
295 | 292 |
|
296 | 293 | def encode_json_obj(self): |
297 | 294 | return self.model.weights, self.tagdict, list(self.classes) |
298 | 295 |
|
299 | 296 | @classmethod |
300 | 297 | def decode_json_obj(cls, obj): |
301 | 298 | tagger = cls(load=False) |
302 | | - tagger.model.weights, tagger.tagdict, tagger.classes = obj |
303 | | - tagger.classes = set(tagger.classes) |
304 | | - tagger.model.classes = tagger.classes |
| 299 | + tagger.decode_json_params(obj) |
305 | 300 | return tagger |
306 | 301 |
|
307 | 302 | def normalize(self, word): |
|
0 commit comments