Skip to content

Commit 360e7cf

Browse files
authored
Merge pull request #239 from amcadmus/devel
fix bug of loading coord when data has only one frame
2 parents 6b3d014 + f209b4f commit 360e7cf

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

source/train/Data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def check_batch_size (self, batch_size) :
9797
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(global_ener_float_precision)
9898
else:
9999
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(global_np_float_precision)
100+
if tmpe.ndim == 1:
101+
tmpe = tmpe.reshape([1,-1])
100102
if tmpe.shape[0] < batch_size :
101103
return ii, tmpe.shape[0]
102104
return None
@@ -106,6 +108,8 @@ def check_test_size (self, test_size) :
106108
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(global_ener_float_precision)
107109
else:
108110
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(global_np_float_precision)
111+
if tmpe.ndim == 1:
112+
tmpe = tmpe.reshape([1,-1])
109113
if tmpe.shape[0] < test_size :
110114
return self.test_dir, tmpe.shape[0]
111115
else :
@@ -271,6 +275,8 @@ def _load_set(self, set_name) :
271275
coord = np.load(path).astype(global_ener_float_precision)
272276
else:
273277
coord = np.load(path).astype(global_np_float_precision)
278+
if coord.ndim == 1:
279+
coord = coord.reshape([1,-1])
274280
nframes = coord.shape[0]
275281
assert(coord.shape[1] == self.data_dict['coord']['ndof'] * self.natoms)
276282
# load keys

0 commit comments

Comments
 (0)