@@ -1,35 +1,8 @@
-from hypothesis import given, settings, HealthCheck
-from transformers import BertTokenizerFast
-from pytest import fixture
-from tibert.bertcoref import CoreferenceDocument, DataCollatorForSpanClassification
+from hypothesis import given
+from tibert.bertcoref import CoreferenceDocument
 from tests.strategies import coref_docs
 
 
-@fixture
-def bert_tokenizer() -> BertTokenizerFast:
-    return BertTokenizerFast.from_pretrained("bert-base-cased")
-
-
-# we suppress the `function_scoped_fixture` health check since we want
-# to execute the `bert_tokenizer` fixture only once.
-@settings(deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture])
-@given(doc=coref_docs(min_size=5, max_size=10, max_span_size=4))
-def test_doc_is_reconstructed(
-    doc: CoreferenceDocument, bert_tokenizer: BertTokenizerFast
-):
-    max_span_size = min(4, len(doc))
-    prep_doc, batch = doc.prepared_document(bert_tokenizer, max_span_size)
-    print(prep_doc)
-    collator = DataCollatorForSpanClassification(bert_tokenizer, max_span_size, "cpu")
-    batch = collator([batch])
-    seq_size = batch["input_ids"].shape[1]
-    wp_to_token = [batch.token_to_word(0, token_index=i) for i in range(seq_size)]
-    reconstructed_doc = prep_doc.from_wpieced_to_tokenized(doc.tokens, wp_to_token)
-
-    assert doc.tokens == reconstructed_doc.tokens
-    assert doc.coref_chains == reconstructed_doc.coref_chains
-
-
 @given(doc=coref_docs())
 def test_mention_labels_number_is_correct(doc: CoreferenceDocument):
     """