1212import logging
1313import random
1414from collections import defaultdict
15+ from os .path import join as path_join
16+ from tempfile import gettempdir
1517
1618from nltk import jsontags
1719from nltk .data import find , load
2224except ImportError :
2325 pass
2426
25- TRAINED_TAGGER_PATH = "averaged_perceptron_tagger/"
26-
27- TAGGER_JSONS = {
28- "eng" : {
29- "weights" : "averaged_perceptron_tagger_eng.weights.json" ,
30- "tagdict" : "averaged_perceptron_tagger_eng.tagdict.json" ,
31- "classes" : "averaged_perceptron_tagger_eng.classes.json" ,
32- },
33- "rus" : {
34- "weights" : "averaged_perceptron_tagger_rus.weights.json" ,
35- "tagdict" : "averaged_perceptron_tagger_rus.tagdict.json" ,
36- "classes" : "averaged_perceptron_tagger_rus.classes.json" ,
37- },
38- "xxx" : {
39- "weights" : "averaged_perceptron_tagger.xxx.weights.json" ,
40- "tagdict" : "averaged_perceptron_tagger.xxx.tagdict.json" ,
41- "classes" : "averaged_perceptron_tagger.xxx.classes.json" ,
42- },
43- }
44-
4527
4628@jsontags .register_tag
4729class AveragedPerceptron :
@@ -145,15 +127,23 @@ class PerceptronTagger(TaggerI):
145127 https://explosion.ai/blog/part-of-speech-pos-tagger-in-python
146128
    >>> from nltk.tag.perceptron import PerceptronTagger
    >>> tagger = PerceptronTagger(load=False)

    Train and save the model:

    >>> tagger.train([[('today','NN'),('is','VBZ'),('good','JJ'),('day','NN')],
    ... [('yes','NNS'),('it','PRP'),('beautiful','JJ')]], save_loc=tagger.save_dir)

    Load the saved model:

    >>> tagger2 = PerceptronTagger(loc=tagger.save_dir)
    >>> print(sorted(list(tagger2.classes)))
    ['JJ', 'NN', 'NNS', 'PRP', 'VBZ']

    >>> print(tagger2.classes == tagger.classes)
    True

    >>> tagger2.tag(['today','is','a','beautiful','day'])
    [('today', 'NN'), ('is', 'PRP'), ('a', 'PRP'), ('beautiful', 'JJ'), ('day', 'NN')]
158148
159149 Use the pretrain model (the default constructor)
@@ -167,20 +157,33 @@ class PerceptronTagger(TaggerI):
167157 [('The', 'DT'), ('red', 'JJ'), ('cat', 'NN')]
168158 """
169159
    # Tag used by nltk.jsontags to register this class for JSON (de)serialization.
    json_tag = "nltk.tag.perceptron.PerceptronTagger"

    # Sentinel tokens that pad the feature-context window at sentence boundaries.
    START = ["-START-", "-START2-"]
    END = ["-END-", "-END2-"]
    def __init__(self, load=True, lang="eng", loc=None):
        """
        :param load: Load the json model upon instantiation.
        :param lang: Language code used to name/locate the parameter files.
        :param loc: Optional directory to load the model from; when None,
            the standard nltk_data tagger location for *lang* is searched.
        """
        self.model = AveragedPerceptron()
        self.tagdict = {}
        self.classes = set()
        self.lang = lang
        # Save trained models in tmp directory by default:
        self.TRAINED_TAGGER_PATH = gettempdir()
        self.TAGGER_NAME = "averaged_perceptron_tagger"
        # Default directory used by train()/save_to_json() for this language;
        # depends on the attributes assigned just above.
        self.save_dir = path_join(
            self.TRAINED_TAGGER_PATH, f"{self.TAGGER_NAME}_{self.lang}"
        )
        if load:
            self.load_from_json(lang, loc)
181+
182+ def param_files (self , lang = "eng" ):
183+ return (
184+ f"{ self .TAGGER_NAME } _{ lang } .{ attr } .json"
185+ for attr in ["weights" , "tagdict" , "classes" ]
186+ )
184187
185188 def tag (self , tokens , return_conf = False , use_tagdict = True ):
186189 """
@@ -253,42 +256,47 @@ def train(self, sentences, save_loc=None, nr_iter=5):
253256 self .model .average_weights ()
254257 # Save to json files.
255258 if save_loc is not None :
256- self .save_to_json (loc )
257-
258- def save_to_json (self , loc , lang = "xxx" ):
259- # TODO:
260- assert os .isdir (
261- TRAINED_TAGGER_PATH
262- ), f"Path set for saving needs to be a directory"
263-
264- with open (loc + TAGGER_JSONS [ lang ][ "weights" ], "w" ) as fout :
265- json . dump ( self . model . weights , fout )
266- with open ( loc + TAGGER_JSONS [ lang ][ "tagdict" ], "w" ) as fout :
267- json . dump (self .tagdict , fout )
268- with open (loc + TAGGER_JSONS [ lang ][ "classes" ] , "w" ) as fout :
269- json .dump (self . classes , fout )
270-
271- def load_from_json (self , lang = "eng" ):
259+ self .save_to_json (lang = self . lang , loc = save_loc )
260+
261+ def save_to_json (self , lang = "xxx" , loc = None ):
262+ from os import mkdir
263+ from os .path import isdir
264+
265+ if not loc :
266+ loc = self . save_dir
267+ if not isdir (loc ) :
268+ mkdir ( loc )
269+
270+ for param , json_file in zip (self .encode_json_obj (), self . param_files ( lang )):
271+ with open (path_join ( loc , json_file ) , "w" ) as fout :
272+ json .dump (param , fout )
273+
274+ def load_from_json (self , lang = "eng" , loc = None ):
272275 # Automatically find path to the tagger if location is not specified.
273- loc = find (f"taggers/averaged_perceptron_tagger_{ lang } /" )
274- with open (loc + TAGGER_JSONS [lang ]["weights" ]) as fin :
275- self .model .weights = json .load (fin )
276- with open (loc + TAGGER_JSONS [lang ]["tagdict" ]) as fin :
277- self .tagdict = json .load (fin )
278- with open (loc + TAGGER_JSONS [lang ]["classes" ]) as fin :
279- self .classes = set (json .load (fin ))
276+ if not loc :
277+ loc = find (f"taggers/averaged_perceptron_tagger_{ lang } " )
280278
281- self .model .classes = self .classes
279+ def load_param (json_file ):
280+ with open (path_join (loc , json_file )) as fin :
281+ return json .load (fin )
282+
283+ self .decode_json_params (
284+ load_param (js_file ) for js_file in self .param_files (lang )
285+ )
286+
287+ def decode_json_params (self , params ):
288+ weights , tagdict , class_list = params
289+ self .model .weights = weights
290+ self .tagdict = tagdict
291+ self .classes = self .model .classes = set (class_list )
282292
283293 def encode_json_obj (self ):
284294 return self .model .weights , self .tagdict , list (self .classes )
285295
    @classmethod
    def decode_json_obj(cls, obj):
        """Recreate a tagger from a JSON-decoded (weights, tagdict, classes)
        triple, as produced by :meth:`encode_json_obj`."""
        # load=False: parameters come from obj, not from the installed model.
        tagger = cls(load=False)
        tagger.decode_json_params(obj)
        return tagger
293301
294302 def normalize (self , word ):
@@ -362,38 +370,24 @@ def _pc(n, d):
362370 return (n / d ) * 100
363371
364372
365- def _load_data_conll_format (filename ):
366- print ("Read from file: " , filename )
367- with open (filename , "rb" ) as fin :
368- sentences = []
369- sentence = []
370- for line in fin .readlines ():
371- line = line .strip ()
372- # print line
373- if len (line ) == 0 :
374- sentences .append (sentence )
375- sentence = []
376- continue
377- tokens = line .split ("\t " )
378- word = tokens [1 ]
379- tag = tokens [4 ]
380- sentence .append ((word , tag ))
381- return sentences
382-
383-
384- def _get_pretrain_model ():
385- # Train and test on English part of ConLL data (WSJ part of Penn Treebank)
386- # Train: section 2-11
387- # Test : section 23
388- tagger = PerceptronTagger ()
389- training = _load_data_conll_format ("english_ptb_train.conll" )
390- testing = _load_data_conll_format ("english_ptb_test.conll" )
391- print ("Size of training and testing (sentence)" , len (training ), len (testing ))
def _train_and_test(lang="sv"):
    """
    Train and test on 'lang' part of universal_treebanks corpus, which includes
    train and test sets in conll format for 'de', 'es', 'fi', 'fr' and 'sv'.
    Finds 0.94 accuracy on 'sv' (Swedish) test set.
    """
    # Local import: the corpus is only needed for this demo helper.
    from nltk.corpus import universal_treebanks as utb

    tagger = PerceptronTagger(load=False, lang=lang)
    training = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-train.conll")
    testing = utb.tagged_sents(f"ch/{lang}/{lang}-universal-ch-test.conll")
    print(
        f"(Lang = {lang}) training on {len(training)} and testing on {len(testing)} sentences"
    )
    # Train and save the model
    tagger.train(training, save_loc=tagger.save_dir)
    print("Accuracy : ", tagger.accuracy(testing))
396391
if __name__ == "__main__":
    # Demo entry point: train on a universal_treebanks language and report accuracy.
    _train_and_test()
0 commit comments