@@ -290,16 +290,25 @@ def _gen_data(self):
class TokenClsInputExample(InputExample):
    """A single training/test example for simple sequence token classification."""

    def __init__(
        self,
        guid: str,
        text: str,
        tokens: List[str],
        shapes: List[int] = None,
        label: List[str] = None,
    ):
        """Construct a TokenClsInputExample.

        Args:
            guid: Unique id for the example.
            text: The untokenized text of the sequence.
            tokens (List[str]): The list of tokens.
            shapes (List[int], optional): Per-token shape ids (capitalization
                categories) aligned index-for-index with ``tokens``.
            label (List[str], optional): The tags of the tokens.
        """
        super(TokenClsInputExample, self).__init__(guid, text, label)
        self.tokens = tokens
        self.shapes = shapes
303312
304313
305314class TokenClsProcessor (DataProcessor ):
@@ -309,12 +318,13 @@ class TokenClsProcessor(DataProcessor):
309318 Label dictionary is given in labels.txt file.
310319 """
311320
312- def __init__ (self , data_dir , tag_col : int = - 1 ):
321+ def __init__ (self , data_dir , tag_col : int = - 1 , ignore_token = None ):
313322 if not os .path .exists (data_dir ):
314323 raise FileNotFoundError
315324 self .data_dir = data_dir
316325 self .tag_col = tag_col
317326 self .labels = None
327+ self .ignore_token = ignore_token
318328
319329 def _read_examples (self , data_dir , file_name , set_name ):
320330 if not os .path .exists (data_dir + os .sep + file_name ):
@@ -325,7 +335,11 @@ def _read_examples(self, data_dir, file_name, set_name):
325335 )
326336 return None
327337 return self ._create_examples (
328- read_column_tagged_file (os .path .join (data_dir , file_name ), tag_col = self .tag_col ),
338+ read_column_tagged_file (
339+ os .path .join (data_dir , file_name ),
340+ tag_col = self .tag_col ,
341+ ignore_token = self .ignore_token ,
342+ ),
329343 set_name ,
330344 )
331345
@@ -359,19 +373,31 @@ def get_labels_filename():
359373 return "labels.txt"
360374
361375 @staticmethod
362- def _create_examples (lines , set_type ):
376+ def _get_shape (string ):
377+ if all (c .isupper () for c in string ):
378+ return 1 # "AA"
379+ if string [0 ].isupper ():
380+ return 2 # "Aa"
381+ if any (c for c in string if c .isupper ()):
382+ return 3 # "aAa"
383+ return 4 # "a"
384+
385+ @classmethod
386+ def _create_examples (cls , lines , set_type ):
363387 """See base class."""
364388 examples = []
365389 for i , (sentence , labels ) in enumerate (lines ):
366390 guid = "%s-%s" % (set_type , i )
367391 text = " " .join (sentence )
392+ shapes = [cls ._get_shape (w ) for w in sentence ]
368393 examples .append (
369- TokenClsInputExample (guid = guid , text = text , tokens = sentence , label = labels )
394+ TokenClsInputExample (
395+ guid = guid , text = text , tokens = sentence , label = labels , shapes = shapes
396+ )
370397 )
371398 return examples
372399
373- def get_vocabulary (self ):
374- examples = self .get_train_examples () + self .get_dev_examples () + self .get_test_examples ()
400+ def get_vocabulary (self , examples : TokenClsInputExample = None ):
375401 vocab = Vocabulary (start = 1 )
376402 for e in examples :
377403 for t in e .tokens :
0 commit comments