Skip to content

Commit 5212b6c

Browse files
committed
Add stream_predict_coref
1 parent e85b8d5 commit 5212b6c

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "tibert"
3-
version = "0.2.3"
3+
version = "0.2.4"
44
description = "BERT for Coreference Resolution"
55
authors = ["Arthur Amalvy <[email protected]>"]
66
license = "GPL-3.0-only"

tibert/predict.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, Literal, List, Union, cast
2+
from typing import TYPE_CHECKING, Generator, Literal, List, Union, cast
33
from transformers import PreTrainedTokenizerFast
44
import torch
55
from tqdm import tqdm
@@ -14,15 +14,16 @@
1414
)
1515

1616

17-
def predict_coref(
17+
def stream_predict_coref(
1818
documents: List[Union[str, List[str]]],
1919
model: BertForCoreferenceResolution,
2020
tokenizer: PreTrainedTokenizerFast,
2121
batch_size: int = 1,
2222
quiet: bool = False,
2323
device_str: Literal["cpu", "cuda", "auto"] = "auto",
2424
lang: str = "en",
25-
) -> List[CoreferenceDocument]:
25+
) -> Generator[CoreferenceDocument, None, None]:
26+
2627
"""Predict coreference chains for a list of documents.
2728
2829
:param documents: A list of documents, tokenized or not. If
@@ -47,7 +48,7 @@ def predict_coref(
4748
device = torch.device(device_str)
4849

4950
if len(documents) == 0:
50-
return []
51+
return
5152

5253
# Tokenized input sentence if needed
5354
if isinstance(documents[0], str):
@@ -72,10 +73,10 @@ def predict_coref(
7273
model = model.eval() # type: ignore
7374
model = model.to(device)
7475

75-
preds = []
76-
7776
with torch.no_grad():
77+
7878
for i, batch in enumerate(tqdm(dataloader, disable=quiet)):
79+
7980
local_batch_size = batch["input_ids"].shape[0]
8081

8182
start_idx = batch_size * i
@@ -84,21 +85,50 @@ def predict_coref(
8485

8586
batch = batch.to(device)
8687
out: BertCoreferenceResolutionOutput = model(**batch)
88+
8789
out_docs = out.coreference_documents(
8890
[
8991
[tokenizer.decode(t) for t in input_ids] # type: ignore
9092
for input_ids in batch["input_ids"]
9193
]
9294
)
93-
out_docs = [
94-
out_doc.from_wpieced_to_tokenized(original_doc.tokens, batch, batch_i)
95-
for batch_i, (original_doc, out_doc) in enumerate(
96-
zip(batch_docs, out_docs)
95+
96+
for batch_i, (original_doc, out_doc) in enumerate(
97+
zip(batch_docs, out_docs)
98+
):
99+
doc = out_doc.from_wpieced_to_tokenized(
100+
original_doc.tokens, batch, batch_i
97101
)
98-
]
99-
preds += out_docs
102+
yield doc
103+
104+
105+
def predict_coref(
106+
documents: List[Union[str, List[str]]],
107+
model: BertForCoreferenceResolution,
108+
tokenizer: PreTrainedTokenizerFast,
109+
batch_size: int = 1,
110+
quiet: bool = False,
111+
device_str: Literal["cpu", "cuda", "auto"] = "auto",
112+
lang: str = "en",
113+
) -> List[CoreferenceDocument]:
114+
"""Predict coreference chains for a list of documents.
115+
116+
:param documents: A list of documents, tokenized or not. If
117+
documents are not tokenized, MosesTokenizer will tokenize them
118+
automatically.
119+
:param tokenizer:
120+
:param batch_size:
121+
:param quiet: If ``True``, will report progress using ``tqdm``.
122+
:param lang: lang for ``MosesTokenizer``
100123
101-
return preds
124+
:return: a list of ``CoreferenceDocument``, with annotated
125+
coreference chains.
126+
"""
127+
return list(
128+
stream_predict_coref(
129+
documents, model, tokenizer, batch_size, quiet, device_str, lang
130+
)
131+
)
102132

103133

104134
def predict_coref_simple(

0 commit comments

Comments
 (0)