Skip to content

Commit 2919614

Browse files
author
Han Wang
committed
warning rather than raising when the required batch size is smaller than the dataset
1 parent 30482b3 commit 2919614

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

source/train/Data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ def get_numb_set (self) :
157157

158158
def get_numb_batch (self, batch_size, set_idx) :
159159
data = self._load_set(self.train_dirs[set_idx])
160-
return data["coord"].shape[0] // batch_size
160+
ret = data["coord"].shape[0] // batch_size
161+
if ret == 0:
162+
ret = 1
163+
return ret
161164

162165
def get_sys_numb_batch (self, batch_size) :
163166
ret = 0

source/train/DataSystem.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os, sys
44
import collections
5+
import warnings
56
import numpy as np
67

78
module_path = os.path.dirname(os.path.realpath(__file__)) + "/"
@@ -76,12 +77,12 @@ def __init__ (self,
7677
for ii in range(self.nsystems) :
7778
chk_ret = self.data_systems[ii].check_batch_size(self.batch_size[ii])
7879
if chk_ret is not None :
79-
raise RuntimeError ("system %s required batch size %d is larger than the size %d of the dataset %s" % \
80-
(self.system_dirs[ii], self.batch_size[ii], chk_ret[1], chk_ret[0]))
80+
warnings.warn("system %s required batch size is larger than the size of the dataset %s (%d > %d)" % \
81+
(self.system_dirs[ii], chk_ret[0], self.batch_size[ii], chk_ret[1]))
8182
chk_ret = self.data_systems[ii].check_test_size(test_size)
8283
if chk_ret is not None :
83-
print("WARNNING: system %s required test size %d is larger than the size %d of the dataset %s" % \
84-
(self.system_dirs[ii], test_size, chk_ret[1], chk_ret[0]))
84+
warnings.warn("system %s required test size is larger than the size of the dataset %s (%d > %d)" % \
85+
(self.system_dirs[ii], chk_ret[0], test_size, chk_ret[1]))
8586

8687
# print summary
8788
if run_opt is not None:

0 commit comments

Comments
 (0)