File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -97,6 +97,8 @@ def check_batch_size (self, batch_size) :
9797 tmpe = np .load (os .path .join (ii , "coord.npy" )).astype (global_ener_float_precision )
9898 else :
9999 tmpe = np .load (os .path .join (ii , "coord.npy" )).astype (global_np_float_precision )
100+ if tmpe .ndim == 1 :
101+ tmpe = tmpe .reshape ([1 ,- 1 ])
100102 if tmpe .shape [0 ] < batch_size :
101103 return ii , tmpe .shape [0 ]
102104 return None
@@ -106,6 +108,8 @@ def check_test_size (self, test_size) :
106108 tmpe = np .load (os .path .join (self .test_dir , "coord.npy" )).astype (global_ener_float_precision )
107109 else :
108110 tmpe = np .load (os .path .join (self .test_dir , "coord.npy" )).astype (global_np_float_precision )
111+ if tmpe .ndim == 1 :
112+ tmpe = tmpe .reshape ([1 ,- 1 ])
109113 if tmpe .shape [0 ] < test_size :
110114 return self .test_dir , tmpe .shape [0 ]
111115 else :
@@ -271,6 +275,8 @@ def _load_set(self, set_name) :
271275 coord = np .load (path ).astype (global_ener_float_precision )
272276 else :
273277 coord = np .load (path ).astype (global_np_float_precision )
278+ if coord .ndim == 1 :
279+ coord = coord .reshape ([1 ,- 1 ])
274280 nframes = coord .shape [0 ]
275281 assert (coord .shape [1 ] == self .data_dict ['coord' ]['ndof' ] * self .natoms )
276282 # load keys
You can’t perform that action at this time.
0 commit comments