Commit 420cb2a

Add CM3 loss to bsmetadata/metadata_utils.py
Parent: c549da0

4 files changed (+27, −2)
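
Background: the CM3 objective (Aghajanyan et al., 2022, "CM3: A Causal Masked Multimodal Model of the Internet") trains a causal language model on sequences in which a contiguous span is cut out, a mask sentinel is left in its place, and the span is re-appended at the end after a second sentinel, so the model learns infilling alongside ordinary left-to-right prediction. This commit wires a single-span variant of that transform into the chunking pipeline behind an opt-in flag. On a toy chunk (letters stand for token ids):

chunk:      [A, B, C, D, E]         with sampled span (1, 3) covering B, C
rearranged: [A, <MASK>, D, E, <MASK>, B, C]

The two sentinels grow each chunk by two tokens, which is why max_text_len is reduced by 2 in bsmetadata/metadata_utils.py below.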

bsmetadata/hydra_configs/v2.yaml

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ data_config:
   metadata_prefix_sep: ' |||'
   metadata_prefix_start_seq: ''
   max_seq_len: 1024
+  apply_cm3_loss_to_sequences: false
   html_parser_config:
     all_tags_rules:
       attributes_to_keep:
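
To enable the transform for a run, a config would flip this flag; a minimal sketch, showing only the relevant keys with the rest of the file as above:

data_config:
  max_seq_len: 1024
  apply_cm3_loss_to_sequences: true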

bsmetadata/metadata_processors.py

Lines changed: 6 additions & 0 deletions

@@ -174,6 +174,12 @@ class MetadataConfig:
     max_seq_len: int = field(
         default=512, metadata={"help": "The maximum number of tokens to use for each training chunk."}
     )
+    apply_cm3_loss_to_sequences: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, the CM3 loss will be applied to training input sequences."
+        },
+    )
     html_parser_config: Optional[HTMLParserConfig] = HTMLParserConfig(
         AllTagsRules(
             attributes_to_keep=None,
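
The flag then rides along on the config object like any other field; a minimal sketch, assuming the remaining MetadataConfig fields also have defaults (as the visible ones do):

from bsmetadata.metadata_processors import MetadataConfig

cfg = MetadataConfig(apply_cm3_loss_to_sequences=True)
assert cfg.apply_cm3_loss_to_sequences   # defaults to False when omitted
assert cfg.max_seq_len == 512            # unrelated fields keep their defaults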

bsmetadata/metadata_utils.py

Lines changed: 12 additions & 0 deletions

@@ -120,10 +120,22 @@ def is_metadata(idx: int) -> bool:
     # Create chunks of `max_seq_len` tokens.
     prefix_len = len(metadata_prefix_encoded)
     max_text_len = cfg.max_seq_len - prefix_len
+    if cfg.apply_cm3_loss_to_sequences:
+        max_text_len -= 2

     for text_chunk_encoded, chunk_metadata_mask in chunks(
         max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
     ):
+        if cfg.apply_cm3_loss_to_sequences:
+            span_start, span_end = random.randint(0, len(text_chunk_encoded)), random.randint(0, len(text_chunk_encoded))
+            if span_end < span_start:
+                span_start, span_end = span_end, span_start
+            if span_end - span_start > 0:
+                text_chunk_encoded = text_chunk_encoded[:span_start] + [tokenizer.mask_token_id] + \
+                    text_chunk_encoded[span_end:] + [tokenizer.mask_token_id] + text_chunk_encoded[span_start:span_end]
+                chunk_metadata_mask = chunk_metadata_mask[:span_start] + [1] + \
+                    chunk_metadata_mask[span_end:] + [1] + chunk_metadata_mask[span_start:span_end]
+
         total_len = prefix_len + len(text_chunk_encoded)
         padding_len = max_text_len - len(text_chunk_encoded)
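
The new branch can be exercised in isolation; below is a self-contained sketch of the same span rotation (the name cm3_rotate and the standalone signature are illustrative, not part of the commit, and the incoming metadata mask is assumed all-zero for brevity):

import random
from typing import List, Tuple


def cm3_rotate(chunk: List[int], mask_id: int) -> Tuple[List[int], List[int]]:
    """Cut a random span, leave a mask sentinel in its place, and re-append
    the span after a second sentinel, mirroring the diff above."""
    span_start = random.randint(0, len(chunk))
    span_end = random.randint(0, len(chunk))
    if span_end < span_start:
        span_start, span_end = span_end, span_start
    metadata_mask = [0] * len(chunk)  # toy assumption: no metadata tokens in this chunk
    if span_end == span_start:  # empty span: the chunk is left untouched
        return chunk, metadata_mask
    rotated = (
        chunk[:span_start] + [mask_id] + chunk[span_end:] + [mask_id] + chunk[span_start:span_end]
    )
    # The two sentinels are marked 1, as in the commit, so downstream code can
    # treat them like metadata tokens rather than ordinary text.
    rotated_mask = (
        metadata_mask[:span_start] + [1] + metadata_mask[span_end:] + [1] + metadata_mask[span_start:span_end]
    )
    return rotated, rotated_mask


random.seed(0)
print(cm3_rotate([10, 11, 12, 13, 14], mask_id=99))

Note that the output is two tokens longer than the input whenever a non-empty span is sampled, which matches the max_text_len -= 2 reservation above.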

bsmetadata/train.py

Lines changed: 8 additions & 2 deletions

@@ -278,10 +278,16 @@ def main(args: CFG) -> None:
         new_tokens = [
             AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False) for token in new_tokens
         ]
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
     else:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        new_tokens = []
+
+    new_tokens += [AddedToken("<MASK>", rstrip=False, lstrip=False, single_word=False, normalized=False)]
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, additional_special_tokens=new_tokens)
+
+    tokenizer.mask_token = "<MASK>"
+    tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
     tokenizer.pad_token = tokenizer.eos_token
+

     if args.data_config.experiment == "with_metadata_datasetv2_tf":
         from bsmetadata.experiments.with_metadata_datasetv2_tf import get_dataloader, get_dummy_dataloader
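
The token registration can be checked against a stock tokenizer; a minimal sketch, with "gpt2" standing in for args.model_name (which is config-dependent):

from transformers import AddedToken, AutoTokenizer

mask = AddedToken("<MASK>", rstrip=False, lstrip=False, single_word=False, normalized=False)
tokenizer = AutoTokenizer.from_pretrained("gpt2", additional_special_tokens=[mask])

tokenizer.mask_token = "<MASK>"
tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
tokenizer.pad_token = tokenizer.eos_token

print(tokenizer.mask_token_id)   # 50257: the first id after gpt2's base vocabulary
print(tokenizer.decode([tokenizer.mask_token_id]))  # <MASK>

Registering the sentinel as a special token keeps "<MASK>" atomic under tokenization, so the span rotation in metadata_utils.py can rely on a single reserved id.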
