File tree Expand file tree Collapse file tree 2 files changed +8
-8
lines changed
examples/text_summarization/bart Expand file tree Collapse file tree 2 files changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -172,11 +172,11 @@ def generate(args):
172
172
ignore_pad_token_for_loss = args .ignore_pad_token_for_loss ,
173
173
is_train = False )
174
174
batchify_fn = lambda samples , fn = Tuple (
175
- Stack (), # input_ids
176
- Stack (), # attention mask
175
+ Stack (dtype = "int64" ), # input_ids
176
+ Stack (dtype = "int64" ), # attention mask
177
177
Stack (dtype = "int32" ), # mem_seq_lens
178
- Stack (), # decoder_input_ids
179
- Stack (), # labels
178
+ Stack (dtype = "int64" ), # decoder_input_ids
179
+ Stack (dtype = "int64" ), # labels
180
180
): fn (samples )
181
181
182
182
dataset = dataset .map (trans_func , lazy = True )
Original file line number Diff line number Diff line change @@ -220,10 +220,10 @@ def do_train(args):
220
220
train_batch_sampler = DistributedBatchSampler (
221
221
train_set , batch_size = args .train_batch_size , shuffle = True )
222
222
batchify_fn = lambda samples , fn = Tuple (
223
- Stack (), # input_ids
224
- Stack (), # attention mask
225
- Stack (), # decoder_input_ids
226
- Stack (), # labels
223
+ Stack (dtype = "int64" ), # input_ids
224
+ Stack (dtype = "int64" ), # attention mask
225
+ Stack (dtype = "int64" ), # decoder_input_ids
226
+ Stack (dtype = "int64" ), # labels
227
227
): fn (samples )
228
228
train_data_loader = DataLoader (
229
229
dataset = train_set ,
You can’t perform that action at this time.
0 commit comments