Skip to content

Commit 47805ac

Browse files
committed
add styling changes to cm3 loss
1 parent 420cb2a commit 47805ac

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

bsmetadata/metadata_processors.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -176,9 +176,7 @@ class MetadataConfig:
176176
)
177177
apply_cm3_loss_to_sequences: bool = field(
178178
default=False,
179-
metadata={
180-
"help": "If True, the CM3 loss will be applied to training input sequences. "
181-
},
179+
metadata={"help": "If True, the CM3 loss will be applied to training input sequences. "},
182180
)
183181
html_parser_config: Optional[HTMLParserConfig] = HTMLParserConfig(
184182
AllTagsRules(

bsmetadata/metadata_utils.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -127,14 +127,23 @@ def is_metadata(idx: int) -> bool:
127127
max_text_len, text_with_local_metadata_encoded.input_ids, token_level_metadata_mask
128128
):
129129
if cfg.apply_cm3_loss_to_sequences:
130-
span_start, span_end = random.randint(0, len(text_chunk_encoded)), random.randint(0, len(text_chunk_encoded))
131-
if span_end < span_start:
132-
span_start, span_end = span_end, span_start
130+
span_ids = sorted([random.randint(0, len(text_chunk_encoded)) for x in range(2)])
131+
span_start, span_end = span_ids[0], span_ids[1]
133132
if span_end - span_start > 0:
134-
text_chunk_encoded = text_chunk_encoded[:span_start] + [tokenizer.mask_token_id] + \
135-
text_chunk_encoded[span_end:] + [tokenizer.mask_token_id] + text_chunk_encoded[span_start: span_end]
136-
chunk_metadata_mask = chunk_metadata_mask[:span_start] + [1] + \
137-
chunk_metadata_mask[span_end:] + [1] + chunk_metadata_mask[span_start: span_end]
133+
text_chunk_encoded = (
134+
text_chunk_encoded[:span_start]
135+
+ [tokenizer.mask_token_id]
136+
+ text_chunk_encoded[span_end:]
137+
+ [tokenizer.mask_token_id]
138+
+ text_chunk_encoded[span_start:span_end]
139+
)
140+
chunk_metadata_mask = (
141+
chunk_metadata_mask[:span_start]
142+
+ [1]
143+
+ chunk_metadata_mask[span_end:]
144+
+ [1]
145+
+ chunk_metadata_mask[span_start:span_end]
146+
)
138147

139148
total_len = prefix_len + len(text_chunk_encoded)
140149
padding_len = max_text_len - len(text_chunk_encoded)

0 commit comments

Comments
 (0)