Skip to content

Commit eb35073

Browse files
authored
Merge pull request #10 from deepmodeling/devel
devel update
2 parents 8b59c84 + 3d74f12 commit eb35073

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ matrix:
6565
env:
6666
- CC=gcc-5
6767
- CXX=g++-5
68-
- TENSORFLOW_VERSION=2.0
68+
- TENSORFLOW_VERSION=2.1
6969
- python: 3.7
7070
env:
7171
- CC=gcc-8
7272
- CXX=g++-8
73-
- TENSORFLOW_VERSION=2.0
73+
- TENSORFLOW_VERSION=2.1
7474
before_install:
7575
- pip install --upgrade pip
7676
- pip install --upgrade setuptools

source/train/Data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ def get_numb_set (self) :
157157

158158
def get_numb_batch (self, batch_size, set_idx) :
159159
data = self._load_set(self.train_dirs[set_idx])
160-
return data["coord"].shape[0] // batch_size
160+
ret = data["coord"].shape[0] // batch_size
161+
if ret == 0:
162+
ret = 1
163+
return ret
161164

162165
def get_sys_numb_batch (self, batch_size) :
163166
ret = 0

source/train/DataSystem.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os, sys
44
import collections
5+
import warnings
56
import numpy as np
67

78
module_path = os.path.dirname(os.path.realpath(__file__)) + "/"
@@ -76,12 +77,12 @@ def __init__ (self,
7677
for ii in range(self.nsystems) :
7778
chk_ret = self.data_systems[ii].check_batch_size(self.batch_size[ii])
7879
if chk_ret is not None :
79-
raise RuntimeError ("system %s required batch size %d is larger than the size %d of the dataset %s" % \
80-
(self.system_dirs[ii], self.batch_size[ii], chk_ret[1], chk_ret[0]))
80+
warnings.warn("system %s required batch size is larger than the size of the dataset %s (%d > %d)" % \
81+
(self.system_dirs[ii], chk_ret[0], self.batch_size[ii], chk_ret[1]))
8182
chk_ret = self.data_systems[ii].check_test_size(test_size)
8283
if chk_ret is not None :
83-
print("WARNNING: system %s required test size %d is larger than the size %d of the dataset %s" % \
84-
(self.system_dirs[ii], test_size, chk_ret[1], chk_ret[0]))
84+
warnings.warn("system %s required test size is larger than the size of the dataset %s (%d > %d)" % \
85+
(self.system_dirs[ii], chk_ret[0], test_size, chk_ret[1]))
8586

8687
# print summary
8788
if run_opt is not None:

0 commit comments

Comments
 (0)