1- from typing import Literal
1+ from typing import Literal , Optional
22import os
33import functools as ft
44from transformers import BertTokenizerFast , CamembertTokenizerFast # type: ignore
55from tqdm import tqdm
66from sacred .experiment import Experiment
77from sacred .run import Run
88from sacred .commands import print_config
9+ from tibert import predict
910from tibert .bertcoref import (
1011 CoreferenceDataset ,
1112 CoreferenceDocument ,
@@ -29,6 +30,7 @@ def config():
2930 dataset_name : str = "litbank"
3031 dataset_path : str = os .path .expanduser ("~/litbank" )
3132 max_span_size : int = 10
33+ limit_doc_size : Optional [int ] = None
3234 hierarchical_merging : bool = False
3335 device_str : str = "auto"
3436 model_path : str
@@ -41,6 +43,7 @@ def main(
4143 dataset_name : Literal ["litbank" , "fr-litbank" , "democrat" ],
4244 dataset_path : str ,
4345 max_span_size : int ,
46+ limit_doc_size : Optional [int ],
4447 hierarchical_merging : bool ,
4548 device_str : Literal ["cuda" , "cpu" , "auto" ],
4649 model_path : str ,
@@ -79,36 +82,45 @@ def main(
7982 )
8083 _ , test_dataset = dataset .splitted (0.9 )
8184
82- all_annotated_docs = []
83- for document in tqdm ( test_dataset . documents ):
84- doc_dataset = CoreferenceDataset (
85- [ document ] ,
85+ if limit_doc_size is None :
86+ all_annotated_docs = predict_coref (
87+ [ doc . tokens for doc in test_dataset . documents ],
88+ model ,
8689 tokenizer ,
87- max_span_size ,
90+ device_str = device_str ,
91+ batch_size = batch_size ,
8892 )
89- if hierarchical_merging :
90- annotated_doc = predict_coref (
91- [doc .tokens for doc in doc_dataset .documents ],
92- model ,
93+ assert isinstance (all_annotated_docs , list )
94+ else :
95+ all_annotated_docs = []
96+ for document in tqdm (test_dataset .documents ):
97+ doc_dataset = CoreferenceDataset (
98+ split_coreference_document_tokens (document , limit_doc_size ),
9399 tokenizer ,
94- hierarchical_merging = True ,
95- quiet = True ,
96- device_str = device_str ,
97- batch_size = batch_size ,
100+ max_span_size ,
98101 )
99- else :
100- annotated_docs = predict_coref (
101- [doc .tokens for doc in doc_dataset .documents ],
102- model ,
103- tokenizer ,
104- hierarchical_merging = False ,
105- quiet = True ,
106- device_str = device_str ,
107- batch_size = batch_size ,
108- )
109- assert isinstance (annotated_docs , list )
110- annotated_doc = CoreferenceDocument .concatenated (annotated_docs )
111- all_annotated_docs .append (annotated_doc )
102+ if hierarchical_merging :
103+ annotated_doc = predict_coref (
104+ [doc .tokens for doc in doc_dataset .documents ],
105+ model ,
106+ tokenizer ,
107+ hierarchical_merging = True ,
108+ quiet = True ,
109+ device_str = device_str ,
110+ batch_size = batch_size ,
111+ )
112+ else :
113+ annotated_docs = predict_coref (
114+ [doc .tokens for doc in doc_dataset .documents ],
115+ model ,
116+ tokenizer ,
117+ quiet = True ,
118+ device_str = device_str ,
119+ batch_size = batch_size ,
120+ )
121+ assert isinstance (annotated_docs , list )
122+ annotated_doc = CoreferenceDocument .concatenated (annotated_docs )
123+ all_annotated_docs .append (annotated_doc )
112124
113125 mention_pre , mention_rec , mention_f1 = score_mention_detection (
114126 all_annotated_docs , test_dataset .documents
0 commit comments