@@ -157,13 +157,8 @@ def build_training_sample(sample, target_seq_length,
157157 return train_sample
158158
159159
160- def pad_and_convert_to_numpy (tokens , masked_positions ,
161- masked_labels , pad_id ,
162- max_seq_length , max_seq_length_dec ,
163- masked_spans = None , bos_id = None ,
164- eos_id = None , sentinel_tokens = None ):
165- """Pad sequences and convert them to numpy."""
166-
160+ def merge_subsequent_masks (tokens , masked_spans = None , bos_id = None ,
161+ eos_id = None , sentinel_tokens = None ):
167162 sentinel_tokens = collections .deque (sentinel_tokens )
168163 t5_input = []
169164 (t5_decoder_in , t5_decoder_out ) = ([bos_id ], [])
@@ -189,6 +184,18 @@ def pad_and_convert_to_numpy(tokens, masked_positions,
189184
190185 # Add the remaining tokens to the t5 input
191186 t5_input .extend (tokens [start_index :])
187+ return t5_input , t5_decoder_in , t5_decoder_out
188+
189+
190+ def pad_and_convert_to_numpy (tokens , masked_positions ,
191+ masked_labels , pad_id ,
192+ max_seq_length , max_seq_length_dec ,
193+ masked_spans = None , bos_id = None ,
194+ eos_id = None , sentinel_tokens = None ):
195+ """Pad sequences and convert them to numpy."""
196+
197+ t5_input , t5_decoder_in , t5_decoder_out = merge_subsequent_masks (
198+ tokens , masked_spans , bos_id , eos_id , sentinel_tokens )
192199
193200 # assert (len(t5_input) - len(masked_spans)) + \
194201 # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)