Skip to content

Commit 60608d3

Browse files
author
gongenlei
authored
fix: fix windows dtype to int64 (#1588)
1 parent 6213573 commit 60608d3

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

examples/text_summarization/bart/generate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,11 @@ def generate(args):
172172
ignore_pad_token_for_loss=args.ignore_pad_token_for_loss,
173173
is_train=False)
174174
batchify_fn = lambda samples, fn=Tuple(
175-
Stack(), # input_ids
176-
Stack(), # attention mask
175+
Stack(dtype="int64"), # input_ids
176+
Stack(dtype="int64"), # attention mask
177177
Stack(dtype="int32"), # mem_seq_lens
178-
Stack(), # decoder_input_ids
179-
Stack(), # labels
178+
Stack(dtype="int64"), # decoder_input_ids
179+
Stack(dtype="int64"), # labels
180180
): fn(samples)
181181

182182
dataset = dataset.map(trans_func, lazy=True)

examples/text_summarization/bart/run_summarization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,10 @@ def do_train(args):
220220
train_batch_sampler = DistributedBatchSampler(
221221
train_set, batch_size=args.train_batch_size, shuffle=True)
222222
batchify_fn = lambda samples, fn=Tuple(
223-
Stack(), # input_ids
224-
Stack(), # attention mask
225-
Stack(), # decoder_input_ids
226-
Stack(), # labels
223+
Stack(dtype="int64"), # input_ids
224+
Stack(dtype="int64"), # attention mask
225+
Stack(dtype="int64"), # decoder_input_ids
226+
Stack(dtype="int64"), # labels
227227
): fn(samples)
228228
train_data_loader = DataLoader(
229229
dataset=train_set,

0 commit comments

Comments
 (0)