11from __future__ import annotations
2- from typing import TYPE_CHECKING , Literal , List , Union , cast
2+ from typing import TYPE_CHECKING , Generator , Literal , List , Union , cast
33from transformers import PreTrainedTokenizerFast
44import torch
55from tqdm import tqdm
1414 )
1515
1616
17- def predict_coref (
17+ def stream_predict_coref (
1818 documents : List [Union [str , List [str ]]],
1919 model : BertForCoreferenceResolution ,
2020 tokenizer : PreTrainedTokenizerFast ,
2121 batch_size : int = 1 ,
2222 quiet : bool = False ,
2323 device_str : Literal ["cpu" , "cuda" , "auto" ] = "auto" ,
2424 lang : str = "en" ,
25- ) -> List [CoreferenceDocument ]:
25+ ) -> Generator [CoreferenceDocument , None , None ]:
26+
2627 """Predict coreference chains for a list of documents.
2728
2829 :param documents: A list of documents, tokenized or not. If
@@ -47,7 +48,7 @@ def predict_coref(
4748 device = torch .device (device_str )
4849
4950 if len (documents ) == 0 :
50- return []
51+ return
5152
5253 # Tokenized input sentence if needed
5354 if isinstance (documents [0 ], str ):
@@ -72,10 +73,10 @@ def predict_coref(
7273 model = model .eval () # type: ignore
7374 model = model .to (device )
7475
75- preds = []
76-
7776 with torch .no_grad ():
77+
7878 for i , batch in enumerate (tqdm (dataloader , disable = quiet )):
79+
7980 local_batch_size = batch ["input_ids" ].shape [0 ]
8081
8182 start_idx = batch_size * i
@@ -84,21 +85,50 @@ def predict_coref(
8485
8586 batch = batch .to (device )
8687 out : BertCoreferenceResolutionOutput = model (** batch )
88+
8789 out_docs = out .coreference_documents (
8890 [
8991 [tokenizer .decode (t ) for t in input_ids ] # type: ignore
9092 for input_ids in batch ["input_ids" ]
9193 ]
9294 )
93- out_docs = [
94- out_doc .from_wpieced_to_tokenized (original_doc .tokens , batch , batch_i )
95- for batch_i , (original_doc , out_doc ) in enumerate (
96- zip (batch_docs , out_docs )
95+
96+ for batch_i , (original_doc , out_doc ) in enumerate (
97+ zip (batch_docs , out_docs )
98+ ):
99+ doc = out_doc .from_wpieced_to_tokenized (
100+ original_doc .tokens , batch , batch_i
97101 )
98- ]
99- preds += out_docs
102+ yield doc
103+
104+
def predict_coref(
    documents: List[Union[str, List[str]]],
    model: BertForCoreferenceResolution,
    tokenizer: PreTrainedTokenizerFast,
    batch_size: int = 1,
    quiet: bool = False,
    device_str: Literal["cpu", "cuda", "auto"] = "auto",
    lang: str = "en",
) -> List[CoreferenceDocument]:
    """Predict coreference chains for a list of documents.

    Eager counterpart of :func:`stream_predict_coref`: every argument
    is forwarded unchanged and the yielded documents are collected
    into a list, so nothing is returned before all documents have been
    processed.

    :param documents: A list of documents, tokenized or not.  If
        documents are not tokenized, MosesTokenizer will tokenize them
        automatically.
    :param model: the coreference resolution model used to produce
        the predictions.
    :param tokenizer: the tokenizer paired with ``model``.
    :param batch_size: number of documents per inference batch.
    :param quiet: If ``True``, the ``tqdm`` progress bar is disabled.
    :param device_str: device selector, passed as-is to
        :func:`stream_predict_coref`.
    :param lang: lang for ``MosesTokenizer``

    :return: a list of ``CoreferenceDocument``, with annotated
        coreference chains.  The list is empty when ``documents`` is
        empty.
    """
    # Forward by keyword so this wrapper stays correct even if the
    # streaming function's parameter order ever changes.
    return list(
        stream_predict_coref(
            documents,
            model,
            tokenizer,
            batch_size=batch_size,
            quiet=quiet,
            device_str=device_str,
            lang=lang,
        )
    )
102132
103133
104134def predict_coref_simple (
0 commit comments