Skip to content

Commit e85b8d5

Browse files
committed
Add support for Democrat
1 parent 9033244 commit e85b8d5

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

tibert/bertcoref.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def from_conll2012_file(
400400
max_span_size: int,
401401
tokens_split_idx: int,
402402
corefs_split_idx: int,
403+
separator: str = "\t",
403404
) -> CoreferenceDataset:
404405
"""
405406
:param tokens_split_idx: index of the tokens column in the
@@ -435,7 +436,7 @@ def from_conll2012_file(
435436
open_mentions = {}
436437
continue
437438

438-
splitted = line.split("\t")
439+
splitted = line.split(separator)
439440

440441
# - tokens
441442
document_tokens.append(splitted[tokens_split_idx])
@@ -453,7 +454,7 @@ def from_conll2012_file(
453454
# - A ending parenthesis indicate the end of a mention
454455
# - The middle number indicates the ID of the coreference chain
455456
# the mention belongs to
456-
if splitted[4] == "-":
457+
if splitted[corefs_split_idx] == "-":
457458
continue
458459

459460
coref_datas_list = splitted[corefs_split_idx].split("|")
@@ -635,6 +636,21 @@ def load_litbank_dataset(
635636
)
636637

637638

639+
def load_democrat_dataset(
640+
root_path: str, tokenizer: PreTrainedTokenizerFast, max_span_size: int
641+
) -> CoreferenceDataset:
642+
"Load the Democrat dataset from the boberle/coreference_databases repository."
643+
root_path = os.path.expanduser(root_path.rstrip("/"))
644+
return CoreferenceDataset.from_conll2012_file(
645+
f"{root_path}/democrat_dem1921/dem1921_base.conll",
646+
tokenizer,
647+
max_span_size,
648+
3,
649+
11,
650+
separator=" ",
651+
)
652+
653+
638654
def load_fr_litbank_dataset(
639655
root_path: str, tokenizer: PreTrainedTokenizerFast, max_span_size: int
640656
):

tibert/run_train.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
train_coref_model,
1414
load_train_checkpoint,
1515
)
16-
from tibert.bertcoref import CoreferenceDataset
16+
from tibert.bertcoref import CoreferenceDataset, load_democrat_dataset
1717

1818
ex = Experiment()
1919

@@ -22,7 +22,7 @@
2222
def config():
2323
batch_size: int = 1
2424
epochs_nb: int = 30
25-
# either "litbank" or "fr-litbank"
25+
# either "litbank", "fr-litbank" or "democrat"
2626
dataset_name: str = "litbank"
2727
dataset_path: str = os.path.expanduser("~/litbank")
2828
mentions_per_tokens: float = 0.4
@@ -45,7 +45,7 @@ def main(
4545
_run: Run,
4646
batch_size: int,
4747
epochs_nb: int,
48-
dataset_name: Literal["litbank", "fr-litbank"],
48+
dataset_name: Literal["litbank", "fr-litbank", "democrat"],
4949
dataset_path: str,
5050
mentions_per_tokens: float,
5151
antecedents_nb: int,
@@ -74,6 +74,11 @@ def main(
7474
"tokenizer_class": CamembertTokenizerFast,
7575
"loading_function": load_fr_litbank_dataset,
7676
},
77+
"democrat": {
78+
"model_class": CamembertForCoreferenceResolution,
79+
"tokenizer_class": CamembertTokenizerFast,
80+
"loading_function": load_democrat_dataset,
81+
},
7782
}
7883

7984
if not dataset_name in dataset_configs:

0 commit comments

Comments
 (0)