 from __future__ import annotations
-from typing import TYPE_CHECKING, Literal, List
+from typing import TYPE_CHECKING, Literal, List, Union, cast
 from transformers import PreTrainedTokenizerFast
 import torch
 from tqdm import tqdm
 
 
 def predict_coref(
-    documents: List[str],
+    documents: List[Union[str, List[str]]],
     model: BertForCoreferenceResolution,
     tokenizer: PreTrainedTokenizerFast,
     batch_size: int = 1,
@@ -24,14 +24,15 @@ def predict_coref(
2424) -> List [CoreferenceDocument ]:
2525 """Predict coreference chains for a list of documents.
2626
27- :param documents: A list of tokenized documents.
27+ :param documents: A list of documents, tokenized or not. If
28+ documents are not tokenized, MosesTokenizer will tokenize them
29+ automatically.
2830 :param tokenizer:
2931 :param batch_size:
30- :param quiet: If ``True``, will report progress using
31- ``tqdm``.
32+ :param quiet: If ``True``, will report progress using ``tqdm``.
3233
3334 :return: a list of ``CoreferenceDocument``, with annotated
34- coreference chains.
35+ coreference chains.
3536 """
3637 from tibert import (
3738 CoreferenceDataset ,
@@ -43,10 +44,18 @@ def predict_coref(
         device_str = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device_str)
 
-    m_tokenizer = MosesTokenizer(lang="en")
-    tokenized_documents = [
-        m_tokenizer.tokenize(text, escape=False) for text in documents
-    ]
+    if len(documents) == 0:
+        return []
+
+    # Tokenize the input documents if needed
+    if isinstance(documents[0], str):
+        m_tokenizer = MosesTokenizer(lang="en")
+        tokenized_documents = [
+            m_tokenizer.tokenize(text, escape=False) for text in documents
+        ]
+    else:
+        tokenized_documents = documents
+    tokenized_documents = cast(List[List[str]], tokenized_documents)
 
     dataset = CoreferenceDataset(
         [CoreferenceDocument(doc, []) for doc in tokenized_documents],
@@ -93,7 +102,10 @@ def predict_coref(
 
 
 def predict_coref_simple(
-    text: str, model, tokenizer, device_str: Literal["cpu", "cuda", "auto"] = "auto"
+    text: Union[str, List[str]],
+    model,
+    tokenizer,
+    device_str: Literal["cpu", "cuda", "auto"] = "auto",
 ) -> CoreferenceDocument:
     annotated_docs = predict_coref(
         [text], model, tokenizer, batch_size=1, device_str=device_str, quiet=True
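
Usage note: after this change, predict_coref and predict_coref_simple accept either raw strings or pre-tokenized documents. The sketch below illustrates both call styles; the checkpoint names, the top-level imports, and the from_pretrained loading shown here are assumptions about the surrounding library interface, not part of this commit.

    from transformers import BertTokenizerFast
    from tibert import BertForCoreferenceResolution, predict_coref, predict_coref_simple

    # Assumed placeholder checkpoints: substitute whichever model/tokenizer pair you use.
    model = BertForCoreferenceResolution.from_pretrained("path/to/coref-checkpoint")
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

    # Raw strings: predict_coref now tokenizes them itself with MosesTokenizer.
    docs = predict_coref(
        ["Jane met her brother. She was happy to see him."],
        model,
        tokenizer,
        batch_size=1,
    )

    # Pre-tokenized input (one list of tokens per document) still works as before.
    docs_pretok = predict_coref(
        [["Jane", "met", "her", "brother", ".", "She", "was", "happy", "."]],
        model,
        tokenizer,
    )

    # The single-document wrapper accepts either form as well.
    doc = predict_coref_simple("Jane met her brother. She was happy to see him.", model, tokenizer)

Passing pre-tokenized documents bypasses the MosesTokenizer branch entirely, which is useful when token boundaries must match an existing annotation layer.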