Skip to content

Commit 9eac19e

Browse files
authored
[Feature] Support is_split_into_words in the TokenClassificationPipeline. (#38818)
* some fixes * some fixes * now the pipeline can take list of tokens as input and is_split_into_words argument * now the pipeline can take list of tokens as input and is_split_into_words argument * now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input * now the pipeline can take list of tokens as input and is_split_into_words argument and we can handle batches of tokenized input * solving test problems * some fixes * some fixes * modify tests * aligning start and end correctly * adding tests * some formatting * some formatting * some fixes * some fixes * some fixes * resolve conflicts * removing unimportant lines * removing unimportant lines * generalize to other languages * generalize to other languages * generalize to other languages * generalize to other languages
1 parent 2ce02b9 commit 9eac19e

File tree

2 files changed

+141
-11
lines changed

2 files changed

+141
-11
lines changed

src/transformers/pipelines/token_classification.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
3030
"""
3131

3232
def __call__(self, inputs: Union[str, list[str]], **kwargs):
33+
is_split_into_words = kwargs.get("is_split_into_words", False)
34+
delimiter = kwargs.get("delimiter", None)
35+
3336
if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
3437
inputs = list(inputs)
3538
batch_size = len(inputs)
3639
elif isinstance(inputs, str):
3740
inputs = [inputs]
3841
batch_size = 1
3942
elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
40-
return inputs, None
43+
return inputs, is_split_into_words, None, delimiter
4144
else:
4245
raise ValueError("At least one input is required.")
4346

@@ -47,7 +50,7 @@ def __call__(self, inputs: Union[str, list[str]], **kwargs):
4750
offset_mapping = [offset_mapping]
4851
if len(offset_mapping) != batch_size:
4952
raise ValueError("offset_mapping should have the same batch size as the input")
50-
return inputs, offset_mapping
53+
return inputs, is_split_into_words, offset_mapping, delimiter
5154

5255

5356
class AggregationStrategy(ExplicitEnum):
@@ -135,6 +138,7 @@ class TokenClassificationPipeline(ChunkPipeline):
135138

136139
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
137140
super().__init__(*args, **kwargs)
141+
138142
self.check_model_type(
139143
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
140144
if self.framework == "tf"
@@ -151,9 +155,16 @@ def _sanitize_parameters(
151155
ignore_subwords: Optional[bool] = None,
152156
aggregation_strategy: Optional[AggregationStrategy] = None,
153157
offset_mapping: Optional[list[tuple[int, int]]] = None,
158+
is_split_into_words: Optional[bool] = False,
154159
stride: Optional[int] = None,
160+
delimiter: Optional[str] = None,
155161
):
156162
preprocess_params = {}
163+
preprocess_params["is_split_into_words"] = is_split_into_words
164+
165+
if is_split_into_words:
166+
preprocess_params["delimiter"] = " " if delimiter is None else delimiter
167+
157168
if offset_mapping is not None:
158169
preprocess_params["offset_mapping"] = offset_mapping
159170

@@ -230,8 +241,9 @@ def __call__(
230241
Classify each token of the text(s) given as inputs.
231242
232243
Args:
233-
inputs (`str` or `list[str]`):
234-
One or several texts (or one list of texts) for token classification.
244+
inputs (`str` or `list[str]`):
245+
One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
246+
`is_split_into_words=True`.
235247
236248
Return:
237249
A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
@@ -251,7 +263,11 @@ def __call__(
251263
exists if the offsets are available within the tokenizer
252264
"""
253265

254-
_inputs, offset_mapping = self._args_parser(inputs, **kwargs)
266+
_inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
267+
kwargs["is_split_into_words"] = is_split_into_words
268+
kwargs["delimiter"] = delimiter
269+
if is_split_into_words and not all(isinstance(input, list) for input in inputs):
270+
return super().__call__([inputs], **kwargs)
255271
if offset_mapping:
256272
kwargs["offset_mapping"] = offset_mapping
257273

@@ -260,14 +276,43 @@ def __call__(
260276
def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
261277
tokenizer_params = preprocess_params.pop("tokenizer_params", {})
262278
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
279+
280+
word_to_chars_map = None
281+
is_split_into_words = preprocess_params["is_split_into_words"]
282+
if is_split_into_words:
283+
delimiter = preprocess_params["delimiter"]
284+
if not isinstance(sentence, list):
285+
raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
286+
words = sentence
287+
sentence = delimiter.join(words) # Recreate the sentence string for later display and slicing
288+
# This map allows converting word indices back to character offsets in the sentence
289+
word_to_chars_map = []
290+
delimiter_len = len(delimiter)
291+
char_offset = 0
292+
for word in words:
293+
word_to_chars_map.append((char_offset, char_offset + len(word)))
294+
char_offset += len(word) + delimiter_len
295+
296+
# We use `words` as the actual input for the tokenizer
297+
text_to_tokenize = words
298+
tokenizer_params["is_split_into_words"] = True
299+
else:
300+
if not isinstance(sentence, str):
301+
raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
302+
text_to_tokenize = sentence
303+
263304
inputs = self.tokenizer(
264-
sentence,
305+
text_to_tokenize,
265306
return_tensors=self.framework,
266307
truncation=truncation,
267308
return_special_tokens_mask=True,
268309
return_offsets_mapping=self.tokenizer.is_fast,
269310
**tokenizer_params,
270311
)
312+
313+
if is_split_into_words and not self.tokenizer.is_fast:
314+
raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
315+
271316
inputs.pop("overflow_to_sample_mapping", None)
272317
num_chunks = len(inputs["input_ids"])
273318

@@ -278,8 +323,12 @@ def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
278323
model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
279324
if offset_mapping is not None:
280325
model_inputs["offset_mapping"] = offset_mapping
326+
281327
model_inputs["sentence"] = sentence if i == 0 else None
282328
model_inputs["is_last"] = i == num_chunks - 1
329+
if word_to_chars_map is not None:
330+
model_inputs["word_ids"] = inputs.word_ids(i)
331+
model_inputs["word_to_chars_map"] = word_to_chars_map
283332

284333
yield model_inputs
285334

@@ -289,6 +338,9 @@ def _forward(self, model_inputs):
289338
offset_mapping = model_inputs.pop("offset_mapping", None)
290339
sentence = model_inputs.pop("sentence")
291340
is_last = model_inputs.pop("is_last")
341+
word_ids = model_inputs.pop("word_ids", None)
342+
word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
343+
292344
if self.framework == "tf":
293345
logits = self.model(**model_inputs)[0]
294346
else:
@@ -301,13 +353,19 @@ def _forward(self, model_inputs):
301353
"offset_mapping": offset_mapping,
302354
"sentence": sentence,
303355
"is_last": is_last,
356+
"word_ids": word_ids,
357+
"word_to_chars_map": word_to_chars_map,
304358
**model_inputs,
305359
}
306360

307361
def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
308362
if ignore_labels is None:
309363
ignore_labels = ["O"]
310364
all_entities = []
365+
366+
# Get the map from the first output; it is the same for all chunks
367+
word_to_chars_map = all_outputs[0].get("word_to_chars_map")
368+
311369
for model_outputs in all_outputs:
312370
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
313371
logits = model_outputs["logits"][0].to(torch.float32).numpy()
@@ -320,6 +378,7 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
320378
model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
321379
)
322380
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
381+
word_ids = model_outputs.get("word_ids")
323382

324383
maxes = np.max(logits, axis=-1, keepdims=True)
325384
shifted_exp = np.exp(logits - maxes)
@@ -330,7 +389,14 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
330389
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
331390

332391
pre_entities = self.gather_pre_entities(
333-
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
392+
sentence,
393+
input_ids,
394+
scores,
395+
offset_mapping,
396+
special_tokens_mask,
397+
aggregation_strategy,
398+
word_ids=word_ids,
399+
word_to_chars_map=word_to_chars_map,
334400
)
335401
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
336402
# Filter anything that is in self.ignore_labels
@@ -374,6 +440,8 @@ def gather_pre_entities(
374440
offset_mapping: Optional[list[tuple[int, int]]],
375441
special_tokens_mask: np.ndarray,
376442
aggregation_strategy: AggregationStrategy,
443+
word_ids: Optional[list[Optional[int]]] = None,
444+
word_to_chars_map: Optional[list[tuple[int, int]]] = None,
377445
) -> list[dict]:
378446
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
379447
pre_entities = []
@@ -385,6 +453,15 @@ def gather_pre_entities(
385453
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
386454
if offset_mapping is not None:
387455
start_ind, end_ind = offset_mapping[idx]
456+
457+
# If the input is pre-tokenized, shift each token's offsets by the start position of its word so that they index into the full sentence.
458+
if word_ids is not None and word_to_chars_map is not None:
459+
word_index = word_ids[idx]
460+
if word_index is not None:
461+
start_char, _ = word_to_chars_map[word_index]
462+
start_ind += start_char
463+
end_ind += start_char
464+
388465
if not isinstance(start_ind, int):
389466
if self.framework == "pt":
390467
start_ind = start_ind.item()

tests/pipelines/test_pipelines_token_classification.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,54 @@ def test_chunking(self):
308308
],
309309
)
310310

311+
@require_torch
312+
@slow
313+
def test_is_split_into_words(self):
314+
"""
315+
Tests the pipeline with pre-tokenized inputs (is_split_into_words=True)
316+
and validates that the character offsets are correct.
317+
"""
318+
token_classifier = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy="simple")
319+
320+
# Input is a list of words
321+
words = ["Hello", "Sarah", "lives", "in", "New", "York"]
322+
323+
# The reconstructed sentence will be "Hello Sarah lives in New York"
324+
# - "Sarah": starts at index 6, ends at 11
325+
# - "New York": starts at index 21, ends at 29
326+
327+
output = token_classifier(words, is_split_into_words=True)
328+
329+
self.assertEqual(
330+
nested_simplify(output),
331+
[
332+
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
333+
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
334+
],
335+
)
336+
337+
# Also test batching with pre-tokenized inputs
338+
words2 = ["My", "name", "is", "Wolfgang", "and", "I", "live", "in", "Berlin"]
339+
batch_output = token_classifier([words, words2], is_split_into_words=True)
340+
341+
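# Expected for second sentence ("My name is Wolfgang and I live in Berlin")
# NOTE(review): joined with single spaces this sentence is 40 characters long, with "Wolfgang" at [11, 19) and "Berlin" at [34, 40) - confirm the 12/20 and 36/42 values below.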
342+
# - "Wolfgang": starts at 12, ends at 20
343+
# - "Berlin": starts at 36, ends at 42
344+
345+
self.assertEqual(
346+
nested_simplify(batch_output),
347+
[
348+
[
349+
{"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11},
350+
{"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29},
351+
],
352+
[
353+
{"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20},
354+
{"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42},
355+
],
356+
],
357+
)
358+
311359
@require_torch
312360
def test_chunking_fast(self):
313361
# Note: We cannot run the test on "conflicts" on the chunking.
@@ -953,19 +1001,24 @@ def setUp(self):
9531001
def test_simple(self):
9541002
string = "This is a simple input"
9551003

956-
inputs, offset_mapping = self.args_parser(string)
1004+
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(string)
9571005
self.assertEqual(inputs, [string])
1006+
self.assertFalse(is_split_into_words)
9581007
self.assertEqual(offset_mapping, None)
9591008

960-
inputs, offset_mapping = self.args_parser([string, string])
1009+
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser([string, string])
9611010
self.assertEqual(inputs, [string, string])
1011+
self.assertFalse(is_split_into_words)
9621012
self.assertEqual(offset_mapping, None)
9631013

964-
inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
1014+
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
1015+
string, offset_mapping=[(0, 1), (1, 2)]
1016+
)
9651017
self.assertEqual(inputs, [string])
1018+
self.assertFalse(is_split_into_words)
9661019
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
9671020

968-
inputs, offset_mapping = self.args_parser(
1021+
inputs, is_split_into_words, offset_mapping, delimiter = self.args_parser(
9691022
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
9701023
)
9711024
self.assertEqual(inputs, [string, string])

0 commit comments

Comments
 (0)