|
15 | 15 | BertForCoreferenceResolution, |
16 | 16 | CamembertForCoreferenceResolution, |
17 | 17 | ) |
18 | | -from tibert.score import score_coref_predictions |
| 18 | +from tibert.score import score_coref_predictions, score_mention_detection |
19 | 19 | from tibert.predict import predict_coref |
20 | 20 | from tibert.utils import split_coreference_document_tokens |
21 | 21 |
|
@@ -82,7 +82,7 @@ def main( |
82 | 82 | all_annotated_docs = [] |
83 | 83 | for document in tqdm(test_dataset.documents): |
84 | 84 | doc_dataset = CoreferenceDataset( |
85 | | - split_coreference_document_tokens(document, 512), |
| 85 | + [document], |
86 | 86 | tokenizer, |
87 | 87 | max_span_size, |
88 | 88 | ) |
@@ -110,6 +110,17 @@ def main( |
110 | 110 | annotated_doc = CoreferenceDocument.concatenated(annotated_docs) |
111 | 111 | all_annotated_docs.append(annotated_doc) |
112 | 112 |
|
| 113 | + mention_pre, mention_rec, mention_f1 = score_mention_detection( |
| 114 | + all_annotated_docs, test_dataset.documents |
| 115 | + ) |
| 116 | + for metric_key, score in [ |
| 117 | + ("precision", mention_pre), |
| 118 | + ("recall", mention_rec), |
| 119 | + ("f1", mention_f1), |
| 120 | + ]: |
| 121 | + print(f"mention.{metric_key}={score}") |
| 122 | + _run.log_scalar(f"mention.{metric_key}", score) |
| 123 | + |
113 | 124 | scores = score_coref_predictions(all_annotated_docs, test_dataset.documents) |
114 | 125 | for key, score_dict in scores.items(): |
115 | 126 | for metric_key, score in score_dict.items(): |
|
0 commit comments