
Commit 3efc7ef

tibert.run_test now correctly supports multiple document split configurations
1 parent: a1c97fe

File tree

1 file changed: 39 additions, 27 deletions


tibert/run_test.py

Lines changed: 39 additions & 27 deletions
@@ -1,11 +1,12 @@
-from typing import Literal
+from typing import Literal, Optional
 import os
 import functools as ft
 from transformers import BertTokenizerFast, CamembertTokenizerFast # type: ignore
 from tqdm import tqdm
 from sacred.experiment import Experiment
 from sacred.run import Run
 from sacred.commands import print_config
+from tibert import predict
 from tibert.bertcoref import (
     CoreferenceDataset,
     CoreferenceDocument,
@@ -29,6 +30,7 @@ def config():
     dataset_name: str = "litbank"
     dataset_path: str = os.path.expanduser("~/litbank")
     max_span_size: int = 10
+    limit_doc_size: Optional[int] = None
     hierarchical_merging: bool = False
     device_str: str = "auto"
     model_path: str
@@ -41,6 +43,7 @@ def main(
     dataset_name: Literal["litbank", "fr-litbank", "democrat"],
     dataset_path: str,
     max_span_size: int,
+    limit_doc_size: Optional[int],
     hierarchical_merging: bool,
     device_str: Literal["cuda", "cpu", "auto"],
     model_path: str,
@@ -79,36 +82,45 @@ def main(
     )
     _, test_dataset = dataset.splitted(0.9)
 
-    all_annotated_docs = []
-    for document in tqdm(test_dataset.documents):
-        doc_dataset = CoreferenceDataset(
-            [document],
+    if limit_doc_size is None:
+        all_annotated_docs = predict_coref(
+            [doc.tokens for doc in dataset.documents],
+            model,
             tokenizer,
-            max_span_size,
+            device_str=device_str,
+            batch_size=batch_size,
         )
-        if hierarchical_merging:
-            annotated_doc = predict_coref(
-                [doc.tokens for doc in doc_dataset.documents],
-                model,
+        assert isinstance(all_annotated_docs, list)
+    else:
+        all_annotated_docs = []
+        for document in tqdm(test_dataset.documents):
+            doc_dataset = CoreferenceDataset(
+                split_coreference_document_tokens(document, 512),
                 tokenizer,
-                hierarchical_merging=True,
-                quiet=True,
-                device_str=device_str,
-                batch_size=batch_size,
+                max_span_size,
             )
-        else:
-            annotated_docs = predict_coref(
-                [doc.tokens for doc in doc_dataset.documents],
-                model,
-                tokenizer,
-                hierarchical_merging=False,
-                quiet=True,
-                device_str=device_str,
-                batch_size=batch_size,
-            )
-            assert isinstance(annotated_docs, list)
-            annotated_doc = CoreferenceDocument.concatenated(annotated_docs)
-        all_annotated_docs.append(annotated_doc)
+            if hierarchical_merging:
+                annotated_doc = predict_coref(
+                    [doc.tokens for doc in doc_dataset.documents],
+                    model,
+                    tokenizer,
+                    hierarchical_merging=True,
+                    quiet=True,
+                    device_str=device_str,
+                    batch_size=batch_size,
+                )
+            else:
+                annotated_docs = predict_coref(
+                    [doc.tokens for doc in doc_dataset.documents],
+                    model,
+                    tokenizer,
+                    quiet=True,
+                    device_str=device_str,
+                    batch_size=batch_size,
+                )
+                assert isinstance(annotated_docs, list)
+                annotated_doc = CoreferenceDocument.concatenated(annotated_docs)
+            all_annotated_docs.append(annotated_doc)
 
     mention_pre, mention_rec, mention_f1 = score_mention_detection(
         all_annotated_docs, test_dataset.documents
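
For reference, a minimal sketch of how the two code paths added by this commit could be exercised through sacred's config-update mechanism. The experiment object name ex, the import, and the model path are illustrative assumptions and are not taken from this diff.

# Minimal sketch, assuming tibert/run_test.py exposes its sacred Experiment
# instance as ex (not shown in this diff); "path/to/model" is a placeholder.
from tibert.run_test import ex  # hypothetical name for the Experiment object

# Default run: limit_doc_size stays None, so predict_coref is called once
# over the full documents.
ex.run(config_updates={"model_path": "path/to/model"})

# Split run: any non-None limit_doc_size takes the per-document branch, where
# each test document is split with split_coreference_document_tokens,
# predicted chunk by chunk, and re-assembled (hierarchically or via
# CoreferenceDocument.concatenated).
ex.run(
    config_updates={
        "model_path": "path/to/model",
        "limit_doc_size": 512,
        "hierarchical_merging": True,
    }
)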
