@@ -127,14 +127,23 @@ def is_metadata(idx: int) -> bool:
127127 max_text_len , text_with_local_metadata_encoded .input_ids , token_level_metadata_mask
128128 ):
129129 if cfg .apply_cm3_loss_to_sequences :
130- span_start , span_end = random .randint (0 , len (text_chunk_encoded )), random .randint (0 , len (text_chunk_encoded ))
131- if span_end < span_start :
132- span_start , span_end = span_end , span_start
130+ span_ids = sorted ([random .randint (0 , len (text_chunk_encoded )) for x in range (2 )])
131+ span_start , span_end = span_ids [0 ], span_ids [1 ]
133132 if span_end - span_start > 0 :
134- text_chunk_encoded = text_chunk_encoded [:span_start ] + [tokenizer .mask_token_id ] + \
135- text_chunk_encoded [span_end :] + [tokenizer .mask_token_id ] + text_chunk_encoded [span_start : span_end ]
136- chunk_metadata_mask = chunk_metadata_mask [:span_start ] + [1 ] + \
137- chunk_metadata_mask [span_end :] + [1 ] + chunk_metadata_mask [span_start : span_end ]
133+ text_chunk_encoded = (
134+ text_chunk_encoded [:span_start ]
135+ + [tokenizer .mask_token_id ]
136+ + text_chunk_encoded [span_end :]
137+ + [tokenizer .mask_token_id ]
138+ + text_chunk_encoded [span_start :span_end ]
139+ )
140+ chunk_metadata_mask = (
141+ chunk_metadata_mask [:span_start ]
142+ + [1 ]
143+ + chunk_metadata_mask [span_end :]
144+ + [1 ]
145+ + chunk_metadata_mask [span_start :span_end ]
146+ )
138147
139148 total_len = prefix_len + len (text_chunk_encoded )
140149 padding_len = max_text_len - len (text_chunk_encoded )
0 commit comments