Commit 780cb69

Add support for pre-tokenized documents in prediction functions
1 parent 04e5734

File tree: 1 file changed (+23 −11 lines)

tibert/predict.py

Lines changed: 23 additions & 11 deletions
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Literal, List
+from typing import TYPE_CHECKING, Literal, List, Union, cast
 from transformers import PreTrainedTokenizerFast
 import torch
 from tqdm import tqdm
@@ -15,7 +15,7 @@
 
 
 def predict_coref(
-    documents: List[str],
+    documents: List[Union[str, List[str]]],
     model: BertForCoreferenceResolution,
     tokenizer: PreTrainedTokenizerFast,
     batch_size: int = 1,
@@ -24,14 +24,15 @@ def predict_coref(
 ) -> List[CoreferenceDocument]:
     """Predict coreference chains for a list of documents.
 
-    :param documents: A list of tokenized documents.
+    :param documents: A list of documents, tokenized or not. If
+        documents are not tokenized, MosesTokenizer will tokenize them
+        automatically.
     :param tokenizer:
     :param batch_size:
-    :param quiet: If ``True``, will report progress using
-        ``tqdm``.
+    :param quiet: If ``True``, will report progress using ``tqdm``.
 
     :return: a list of ``CoreferenceDocument``, with annotated
-        coreference chains.
+        coreference chains.
     """
     from tibert import (
         CoreferenceDataset,
@@ -43,10 +44,18 @@ def predict_coref(
     device_str = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device_str)
 
-    m_tokenizer = MosesTokenizer(lang="en")
-    tokenized_documents = [
-        m_tokenizer.tokenize(text, escape=False) for text in documents
-    ]
+    if len(documents) == 0:
+        return []
+
+    # Tokenize input documents if needed
+    if isinstance(documents[0], str):
+        m_tokenizer = MosesTokenizer(lang="en")
+        tokenized_documents = [
+            m_tokenizer.tokenize(text, escape=False) for text in documents
+        ]
+    else:
+        tokenized_documents = documents
+    tokenized_documents = cast(List[List[str]], tokenized_documents)
 
     dataset = CoreferenceDataset(
         [CoreferenceDocument(doc, []) for doc in tokenized_documents],
@@ -93,7 +102,10 @@ def predict_coref(
 
 
 def predict_coref_simple(
-    text: str, model, tokenizer, device_str: Literal["cpu", "cuda", "auto"] = "auto"
+    text: Union[str, List[str]],
+    model,
+    tokenizer,
+    device_str: Literal["cpu", "cuda", "auto"] = "auto",
 ) -> CoreferenceDocument:
     annotated_docs = predict_coref(
         [text], model, tokenizer, batch_size=1, device_str=device_str, quiet=True
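
For reference, a minimal usage sketch of the changed functions. The checkpoint and tokenizer names below are assumptions for illustration (they are not part of this commit), and the sketch assumes predict_coref and predict_coref_simple are importable from the package root as in the project README. Note that the new dispatch only inspects documents[0], so each call should pass documents in a single form rather than mixing raw strings and token lists:

    from transformers import BertTokenizerFast
    from tibert import BertForCoreferenceResolution, predict_coref, predict_coref_simple

    # Assumed checkpoint name for illustration; substitute your own model.
    model = BertForCoreferenceResolution.from_pretrained(
        "compnet-renard/bert-base-cased-literary-coref"
    )
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

    # Raw strings: predict_coref tokenizes them internally with MosesTokenizer.
    raw_docs = ["Sli did not want the earpods. He did not like them."]
    annotated = predict_coref(raw_docs, model, tokenizer, batch_size=1)

    # Pre-tokenized documents (List[List[str]]) are now passed through as-is,
    # which is the behaviour this commit adds.
    tokenized_docs = [["Sli", "did", "not", "want", "the", "earpods", "."]]
    annotated = predict_coref(tokenized_docs, model, tokenizer, batch_size=1)

    # predict_coref_simple accepts either form for a single document.
    doc = predict_coref_simple("Sli did not want the earpods.", model, tokenizer)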
