Skip to content

Commit 0267e5b

Browse files
authored
Merge pull request #701 from wangzhen38/del_batch_error
update error info when batchSize > datasetSize
2 parents 64fc8ed + a231e14 commit 0267e5b

File tree

4 files changed

+18
-0
lines changed

4 files changed

+18
-0
lines changed

tools/infer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def main(args):
116116
infer_run_cost = 0.0
117117
reader_start = time.time()
118118

119+
# The last incomplete batch is dropped when the dataset size is not divisible by the batch size, so a batch size larger than the dataset size yields no batches at all.
120+
assert any(test_dataloader(
121+
)), "test_dataloader is null, please ensure batch size < dataset size!"
122+
119123
for batch_id, batch in enumerate(test_dataloader()):
120124
infer_reader_cost += time.time() - reader_start
121125
infer_start = time.time()

tools/static_infer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def main(args):
128128

129129
if use_auc:
130130
reset_auc(use_fleet, auc_num)
131+
132+
# The last incomplete batch is dropped when the dataset size is not divisible by the batch size, so a batch size larger than the dataset size yields no batches at all.
133+
assert any(test_dataloader(
134+
)), "test_dataloader's size is null, please ensure batch size < dataset size!"
135+
131136
for batch_id, batch_data in enumerate(test_dataloader()):
132137
infer_reader_cost += time.time() - reader_start
133138
infer_start = time.time()

tools/static_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
239239
train_run_cost = 0.0
240240
total_samples = 0
241241
reader_start = time.time()
242+
243+
# The last incomplete batch is dropped when the dataset size is not divisible by the batch size, so a batch size larger than the dataset size yields no batches at all.
244+
assert any(train_dataloader(
245+
)), "train_dataloader's size is null, please ensure batch size < dataset size!"
246+
242247
for batch_id, batch_data in enumerate(train_dataloader()):
243248
train_reader_cost += time.time() - reader_start
244249
train_start = time.time()

tools/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def main(args):
127127
total_samples = 0
128128
reader_start = time.time()
129129

130+
# The last incomplete batch is dropped when the dataset size is not divisible by the batch size, so a batch size larger than the dataset size yields no batches at all.
131+
assert any(train_dataloader(
132+
)), "train_dataloader is null, please ensure batch size < dataset size!"
133+
130134
for batch_id, batch in enumerate(train_dataloader()):
131135
train_reader_cost += time.time() - reader_start
132136
optimizer.clear_grad()

0 commit comments

Comments
 (0)