@@ -30,14 +30,17 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
3030 """
3131
3232 def __call__ (self , inputs : Union [str , list [str ]], ** kwargs ):
33+ is_split_into_words = kwargs .get ("is_split_into_words" , False )
34+ delimiter = kwargs .get ("delimiter" , None )
35+
3336 if inputs is not None and isinstance (inputs , (list , tuple )) and len (inputs ) > 0 :
3437 inputs = list (inputs )
3538 batch_size = len (inputs )
3639 elif isinstance (inputs , str ):
3740 inputs = [inputs ]
3841 batch_size = 1
3942 elif Dataset is not None and isinstance (inputs , Dataset ) or isinstance (inputs , types .GeneratorType ):
40- return inputs , None
43+ return inputs , is_split_into_words , None , delimiter
4144 else :
4245 raise ValueError ("At least one input is required." )
4346
@@ -47,7 +50,7 @@ def __call__(self, inputs: Union[str, list[str]], **kwargs):
4750 offset_mapping = [offset_mapping ]
4851 if len (offset_mapping ) != batch_size :
4952 raise ValueError ("offset_mapping should have the same batch size as the input" )
50- return inputs , offset_mapping
53+ return inputs , is_split_into_words , offset_mapping , delimiter
5154
5255
5356class AggregationStrategy (ExplicitEnum ):
@@ -135,6 +138,7 @@ class TokenClassificationPipeline(ChunkPipeline):
135138
136139 def __init__ (self , args_parser = TokenClassificationArgumentHandler (), * args , ** kwargs ):
137140 super ().__init__ (* args , ** kwargs )
141+
138142 self .check_model_type (
139143 TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
140144 if self .framework == "tf"
@@ -151,9 +155,16 @@ def _sanitize_parameters(
151155 ignore_subwords : Optional [bool ] = None ,
152156 aggregation_strategy : Optional [AggregationStrategy ] = None ,
153157 offset_mapping : Optional [list [tuple [int , int ]]] = None ,
158+ is_split_into_words : Optional [bool ] = False ,
154159 stride : Optional [int ] = None ,
160+ delimiter : Optional [str ] = None ,
155161 ):
156162 preprocess_params = {}
163+ preprocess_params ["is_split_into_words" ] = is_split_into_words
164+
165+ if is_split_into_words :
166+ preprocess_params ["delimiter" ] = " " if delimiter is None else delimiter
167+
157168 if offset_mapping is not None :
158169 preprocess_params ["offset_mapping" ] = offset_mapping
159170
@@ -230,8 +241,9 @@ def __call__(
230241 Classify each token of the text(s) given as inputs.
231242
232243 Args:
233- inputs (`str` or `list[str]`):
234- One or several texts (or one list of texts) for token classification.
244+ inputs (`str` or `List[str]`):
245+ One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
246+ `is_split_into_words=True`.
235247
236248 Return:
237249 A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -251,7 +263,11 @@ def __call__(
251263 exists if the offsets are available within the tokenizer
252264 """
253265
254- _inputs , offset_mapping = self ._args_parser (inputs , ** kwargs )
266+ _inputs , is_split_into_words , offset_mapping , delimiter = self ._args_parser (inputs , ** kwargs )
267+ kwargs ["is_split_into_words" ] = is_split_into_words
268+ kwargs ["delimiter" ] = delimiter
269+ if is_split_into_words and not all (isinstance (input , list ) for input in inputs ):
270+ return super ().__call__ ([inputs ], ** kwargs )
255271 if offset_mapping :
256272 kwargs ["offset_mapping" ] = offset_mapping
257273
@@ -260,14 +276,43 @@ def __call__(
260276 def preprocess (self , sentence , offset_mapping = None , ** preprocess_params ):
261277 tokenizer_params = preprocess_params .pop ("tokenizer_params" , {})
262278 truncation = True if self .tokenizer .model_max_length and self .tokenizer .model_max_length > 0 else False
279+
280+ word_to_chars_map = None
281+ is_split_into_words = preprocess_params ["is_split_into_words" ]
282+ if is_split_into_words :
283+ delimiter = preprocess_params ["delimiter" ]
284+ if not isinstance (sentence , list ):
285+ raise ValueError ("When `is_split_into_words=True`, `sentence` must be a list of tokens." )
286+ words = sentence
287+ sentence = delimiter .join (words ) # Recreate the sentence string for later display and slicing
288+ # This map will allows to convert back word => char indices
289+ word_to_chars_map = []
290+ delimiter_len = len (delimiter )
291+ char_offset = 0
292+ for word in words :
293+ word_to_chars_map .append ((char_offset , char_offset + len (word )))
294+ char_offset += len (word ) + delimiter_len
295+
296+ # We use `words` as the actual input for the tokenizer
297+ text_to_tokenize = words
298+ tokenizer_params ["is_split_into_words" ] = True
299+ else :
300+ if not isinstance (sentence , str ):
301+ raise ValueError ("When `is_split_into_words=False`, `sentence` must be an untokenized string." )
302+ text_to_tokenize = sentence
303+
263304 inputs = self .tokenizer (
264- sentence ,
305+ text_to_tokenize ,
265306 return_tensors = self .framework ,
266307 truncation = truncation ,
267308 return_special_tokens_mask = True ,
268309 return_offsets_mapping = self .tokenizer .is_fast ,
269310 ** tokenizer_params ,
270311 )
312+
313+ if is_split_into_words and not self .tokenizer .is_fast :
314+ raise ValueError ("is_split_into_words=True is only supported with fast tokenizers." )
315+
271316 inputs .pop ("overflow_to_sample_mapping" , None )
272317 num_chunks = len (inputs ["input_ids" ])
273318
@@ -278,8 +323,12 @@ def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
278323 model_inputs = {k : v [i ].unsqueeze (0 ) for k , v in inputs .items ()}
279324 if offset_mapping is not None :
280325 model_inputs ["offset_mapping" ] = offset_mapping
326+
281327 model_inputs ["sentence" ] = sentence if i == 0 else None
282328 model_inputs ["is_last" ] = i == num_chunks - 1
329+ if word_to_chars_map is not None :
330+ model_inputs ["word_ids" ] = inputs .word_ids (i )
331+ model_inputs ["word_to_chars_map" ] = word_to_chars_map
283332
284333 yield model_inputs
285334
@@ -289,6 +338,9 @@ def _forward(self, model_inputs):
289338 offset_mapping = model_inputs .pop ("offset_mapping" , None )
290339 sentence = model_inputs .pop ("sentence" )
291340 is_last = model_inputs .pop ("is_last" )
341+ word_ids = model_inputs .pop ("word_ids" , None )
342+ word_to_chars_map = model_inputs .pop ("word_to_chars_map" , None )
343+
292344 if self .framework == "tf" :
293345 logits = self .model (** model_inputs )[0 ]
294346 else :
@@ -301,13 +353,19 @@ def _forward(self, model_inputs):
301353 "offset_mapping" : offset_mapping ,
302354 "sentence" : sentence ,
303355 "is_last" : is_last ,
356+ "word_ids" : word_ids ,
357+ "word_to_chars_map" : word_to_chars_map ,
304358 ** model_inputs ,
305359 }
306360
307361 def postprocess (self , all_outputs , aggregation_strategy = AggregationStrategy .NONE , ignore_labels = None ):
308362 if ignore_labels is None :
309363 ignore_labels = ["O" ]
310364 all_entities = []
365+
366+ # Get map from the first output, it's the same for all chunks
367+ word_to_chars_map = all_outputs [0 ].get ("word_to_chars_map" )
368+
311369 for model_outputs in all_outputs :
312370 if self .framework == "pt" and model_outputs ["logits" ][0 ].dtype in (torch .bfloat16 , torch .float16 ):
313371 logits = model_outputs ["logits" ][0 ].to (torch .float32 ).numpy ()
@@ -320,6 +378,7 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
320378 model_outputs ["offset_mapping" ][0 ] if model_outputs ["offset_mapping" ] is not None else None
321379 )
322380 special_tokens_mask = model_outputs ["special_tokens_mask" ][0 ].numpy ()
381+ word_ids = model_outputs .get ("word_ids" )
323382
324383 maxes = np .max (logits , axis = - 1 , keepdims = True )
325384 shifted_exp = np .exp (logits - maxes )
@@ -330,7 +389,14 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
330389 offset_mapping = offset_mapping .numpy () if offset_mapping is not None else None
331390
332391 pre_entities = self .gather_pre_entities (
333- sentence , input_ids , scores , offset_mapping , special_tokens_mask , aggregation_strategy
392+ sentence ,
393+ input_ids ,
394+ scores ,
395+ offset_mapping ,
396+ special_tokens_mask ,
397+ aggregation_strategy ,
398+ word_ids = word_ids ,
399+ word_to_chars_map = word_to_chars_map ,
334400 )
335401 grouped_entities = self .aggregate (pre_entities , aggregation_strategy )
336402 # Filter anything that is in self.ignore_labels
@@ -374,6 +440,8 @@ def gather_pre_entities(
374440 offset_mapping : Optional [list [tuple [int , int ]]],
375441 special_tokens_mask : np .ndarray ,
376442 aggregation_strategy : AggregationStrategy ,
443+ word_ids : Optional [list [Optional [int ]]] = None ,
444+ word_to_chars_map : Optional [list [tuple [int , int ]]] = None ,
377445 ) -> list [dict ]:
378446 """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
379447 pre_entities = []
@@ -385,6 +453,15 @@ def gather_pre_entities(
385453 word = self .tokenizer .convert_ids_to_tokens (int (input_ids [idx ]))
386454 if offset_mapping is not None :
387455 start_ind , end_ind = offset_mapping [idx ]
456+
457+ # If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
458+ if word_ids is not None and word_to_chars_map is not None :
459+ word_index = word_ids [idx ]
460+ if word_index is not None :
461+ start_char , _ = word_to_chars_map [word_index ]
462+ start_ind += start_char
463+ end_ind += start_char
464+
388465 if not isinstance (start_ind , int ):
389466 if self .framework == "pt" :
390467 start_ind = start_ind .item ()
0 commit comments