1010from deepmd .env import GLOBAL_NP_FLOAT_PRECISION
1111from deepmd .env import GLOBAL_ENER_FLOAT_PRECISION
1212from deepmd .utils import random as dp_random
13+ from deepmd .utils .path import DPPath
1314
1415log = logging .getLogger (__name__ )
1516
@@ -44,17 +45,18 @@ def __init__ (self,
4445 """
4546 Constructor
4647 """
47- self .dirs = glob .glob (os .path .join (sys_path , set_prefix + ".*" ))
48+ root = DPPath (sys_path )
49+ self .dirs = root .glob (set_prefix + ".*" )
4850 self .dirs .sort ()
4951 # load atom type
50- self .atom_type = self ._load_type (sys_path )
52+ self .atom_type = self ._load_type (root )
5153 self .natoms = len (self .atom_type )
5254 # load atom type map
53- self .type_map = self ._load_type_map (sys_path )
55+ self .type_map = self ._load_type_map (root )
5456 if self .type_map is not None :
5557 assert (len (self .type_map ) >= max (self .atom_type )+ 1 )
5658 # check pbc
57- self .pbc = self ._check_pbc (sys_path )
59+ self .pbc = self ._check_pbc (root )
5860 # enforce type_map if necessary
5961 if type_map is not None and self .type_map is not None :
6062 atom_type_ = [type_map .index (self .type_map [ii ]) for ii in self .atom_type ]
@@ -167,9 +169,9 @@ def check_batch_size (self, batch_size) :
167169 """
168170 for ii in self .train_dirs :
169171 if self .data_dict ['coord' ]['high_prec' ] :
170- tmpe = np . load ( os . path . join ( ii , "coord.npy" )).astype (GLOBAL_ENER_FLOAT_PRECISION )
172+ tmpe = ( ii / "coord.npy" ). load_numpy ( ).astype (GLOBAL_ENER_FLOAT_PRECISION )
171173 else :
172- tmpe = np . load ( os . path . join ( ii , "coord.npy" )).astype (GLOBAL_NP_FLOAT_PRECISION )
174+ tmpe = ( ii / "coord.npy" ). load_numpy ( ).astype (GLOBAL_NP_FLOAT_PRECISION )
173175 if tmpe .ndim == 1 :
174176 tmpe = tmpe .reshape ([1 ,- 1 ])
175177 if tmpe .shape [0 ] < batch_size :
@@ -181,9 +183,9 @@ def check_test_size (self, test_size) :
181183 Check if the system can get a test dataset with `test_size` frames.
182184 """
183185 if self .data_dict ['coord' ]['high_prec' ] :
184- tmpe = np . load ( os . path . join ( self .test_dir , "coord.npy" )).astype (GLOBAL_ENER_FLOAT_PRECISION )
186+ tmpe = ( self .test_dir / "coord.npy" ). load_numpy ( ).astype (GLOBAL_ENER_FLOAT_PRECISION )
185187 else :
186- tmpe = np . load ( os . path . join ( self .test_dir , "coord.npy" )).astype (GLOBAL_NP_FLOAT_PRECISION )
188+ tmpe = ( self .test_dir / "coord.npy" ). load_numpy ( ).astype (GLOBAL_NP_FLOAT_PRECISION )
187189 if tmpe .ndim == 1 :
188190 tmpe = tmpe .reshape ([1 ,- 1 ])
189191 if tmpe .shape [0 ] < test_size :
@@ -377,7 +379,7 @@ def _get_subdata(self, data, idx = None) :
377379 return new_data
378380
379381 def _load_batch_set (self ,
380- set_name ) :
382+ set_name : DPPath ) :
381383 self .batch_set = self ._load_set (set_name )
382384 self .batch_set , _ = self ._shuffle_data (self .batch_set )
383385 self .reset_get_batch ()
@@ -386,7 +388,7 @@ def reset_get_batch(self):
386388 self .iterator = 0
387389
388390 def _load_test_set (self ,
389- set_name ,
391+ set_name : DPPath ,
390392 shuffle_test ) :
391393 self .test_set = self ._load_set (set_name )
392394 if shuffle_test :
@@ -409,13 +411,15 @@ def _shuffle_data (self,
409411 ret [kk ] = data [kk ]
410412 return ret , idx
411413
412- def _load_set (self , set_name ) :
414+ def _load_set (self , set_name : DPPath ) :
413415 # get nframes
414- path = os .path .join (set_name , "coord.npy" )
416+ if not isinstance (set_name , DPPath ):
417+ set_name = DPPath (set_name )
418+ path = set_name / "coord.npy"
415419 if self .data_dict ['coord' ]['high_prec' ] :
416- coord = np . load ( path ).astype (GLOBAL_ENER_FLOAT_PRECISION )
420+ coord = path . load_numpy ( ).astype (GLOBAL_ENER_FLOAT_PRECISION )
417421 else :
418- coord = np . load ( path ).astype (GLOBAL_NP_FLOAT_PRECISION )
422+ coord = path . load_numpy ( ).astype (GLOBAL_NP_FLOAT_PRECISION )
419423 if coord .ndim == 1 :
420424 coord = coord .reshape ([1 ,- 1 ])
421425 nframes = coord .shape [0 ]
@@ -459,12 +463,12 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
459463 ndof = ndof_ * natoms
460464 else :
461465 ndof = ndof_
462- path = os . path . join ( set_name , key + ".npy" )
463- if os . path .isfile ( path ) :
466+ path = set_name / ( key + ".npy" )
467+ if path .is_file ( ) :
464468 if high_prec :
465- data = np . load ( path ).astype (GLOBAL_ENER_FLOAT_PRECISION )
469+ data = path . load_numpy ( ).astype (GLOBAL_ENER_FLOAT_PRECISION )
466470 else :
467- data = np . load ( path ).astype (GLOBAL_NP_FLOAT_PRECISION )
471+ data = path . load_numpy ( ).astype (GLOBAL_NP_FLOAT_PRECISION )
468472 try : # YWolfeee: deal with data shape error
469473 if atomic :
470474 data = data .reshape ([nframes , natoms , - 1 ])
@@ -491,8 +495,8 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
491495 return np .float32 (0.0 ), data
492496
493497
494- def _load_type (self , sys_path ) :
495- atom_type = np . loadtxt ( os . path . join ( sys_path , "type.raw" ), dtype = np .int32 , ndmin = 1 )
498+ def _load_type (self , sys_path : DPPath ) :
499+ atom_type = ( sys_path / "type.raw" ). load_txt ( dtype = np .int32 , ndmin = 1 )
496500 return atom_type
497501
498502 def _make_idx_map (self , atom_type ):
@@ -501,17 +505,16 @@ def _make_idx_map(self, atom_type):
501505 idx_map = np .lexsort ((idx , atom_type ))
502506 return idx_map
503507
504- def _load_type_map (self , sys_path ) :
505- fname = os .path .join (sys_path , 'type_map.raw' )
506- if os .path .isfile (fname ) :
507- with open (os .path .join (sys_path , 'type_map.raw' )) as fp :
508- return fp .read ().split ()
508+ def _load_type_map (self , sys_path : DPPath ) :
509+ fname = sys_path / 'type_map.raw'
510+ if fname .is_file () :
511+ return fname .load_txt (dtype = str ).tolist ()
509512 else :
510513 return None
511514
512- def _check_pbc (self , sys_path ):
515+ def _check_pbc (self , sys_path : DPPath ):
513516 pbc = True
514- if os . path . isfile ( os . path . join ( sys_path , 'nopbc' )) :
517+ if ( sys_path / 'nopbc' ). is_file ( ) :
515518 pbc = False
516519 return pbc
517520
0 commit comments