Skip to content

Commit 53f1567

Browse files
authored
add support for hdf5 (#1163)
* make a draft
* add support for hdf5
* fix error with old python
* fix rglobal
* fix tests_path
* fix tests
* Update test_deepmd_data.py
* use `visit` instead of `visititems`
* cache file keys to prevent performance issues
* improve performance
1 parent 5b0ff59 commit 53f1567

File tree

7 files changed

+396
-32
lines changed

7 files changed

+396
-32
lines changed

deepmd/common.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION
2424
from deepmd.utils.sess import run_sess
2525
from deepmd.utils.errors import GraphWithoutTensorError
26+
from deepmd.utils.path import DPPath
2627

2728
if TYPE_CHECKING:
2829
_DICT_VAL = TypeVar("_DICT_VAL")
@@ -429,9 +430,10 @@ def expand_sys_str(root_dir: Union[str, Path]) -> List[str]:
429430
List[str]
430431
list of string pointing to system directories
431432
"""
432-
matches = [str(d) for d in Path(root_dir).rglob("*") if (d / "type.raw").is_file()]
433-
if (Path(root_dir) / "type.raw").is_file():
434-
matches += [root_dir]
433+
root_dir = DPPath(root_dir)
434+
matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()]
435+
if (root_dir / "type.raw").is_file():
436+
matches.append(str(root_dir))
435437
return matches
436438

437439

deepmd/entrypoints/train.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from deepmd.utils.data_system import DeepmdDataSystem
2121
from deepmd.utils.sess import run_sess
2222
from deepmd.utils.neighbor_stat import NeighborStat
23+
from deepmd.utils.path import DPPath
2324

2425
__all__ = ["train"]
2526

@@ -181,11 +182,12 @@ def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
181182
raise IOError(msg, help_msg)
182183
# roughly check all items in systems are valid
183184
for ii in systems:
184-
if (not os.path.isdir(ii)):
185+
ii = DPPath(ii)
186+
if (not ii.is_dir()):
185187
msg = f'dir {ii} is not a valid dir'
186188
log.fatal(msg)
187189
raise IOError(msg, help_msg)
188-
if (not os.path.isfile(os.path.join(ii, 'type.raw'))):
190+
if (not (ii / 'type.raw').is_file()):
189191
msg = f'dir {ii} is not a valid data system dir'
190192
log.fatal(msg)
191193
raise IOError(msg, help_msg)

deepmd/utils/data.py

Lines changed: 30 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
1111
from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION
1212
from deepmd.utils import random as dp_random
13+
from deepmd.utils.path import DPPath
1314

1415
log = logging.getLogger(__name__)
1516

@@ -44,17 +45,18 @@ def __init__ (self,
4445
"""
4546
Constructor
4647
"""
47-
self.dirs = glob.glob (os.path.join(sys_path, set_prefix + ".*"))
48+
root = DPPath(sys_path)
49+
self.dirs = root.glob(set_prefix + ".*")
4850
self.dirs.sort()
4951
# load atom type
50-
self.atom_type = self._load_type(sys_path)
52+
self.atom_type = self._load_type(root)
5153
self.natoms = len(self.atom_type)
5254
# load atom type map
53-
self.type_map = self._load_type_map(sys_path)
55+
self.type_map = self._load_type_map(root)
5456
if self.type_map is not None:
5557
assert(len(self.type_map) >= max(self.atom_type)+1)
5658
# check pbc
57-
self.pbc = self._check_pbc(sys_path)
59+
self.pbc = self._check_pbc(root)
5860
# enforce type_map if necessary
5961
if type_map is not None and self.type_map is not None:
6062
atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type]
@@ -167,9 +169,9 @@ def check_batch_size (self, batch_size) :
167169
"""
168170
for ii in self.train_dirs :
169171
if self.data_dict['coord']['high_prec'] :
170-
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION)
172+
tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
171173
else:
172-
tmpe = np.load(os.path.join(ii, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION)
174+
tmpe = (ii / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
173175
if tmpe.ndim == 1:
174176
tmpe = tmpe.reshape([1,-1])
175177
if tmpe.shape[0] < batch_size :
@@ -181,9 +183,9 @@ def check_test_size (self, test_size) :
181183
Check if the system can get a test dataset with `test_size` frames.
182184
"""
183185
if self.data_dict['coord']['high_prec'] :
184-
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_ENER_FLOAT_PRECISION)
186+
tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
185187
else:
186-
tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(GLOBAL_NP_FLOAT_PRECISION)
188+
tmpe = (self.test_dir / "coord.npy").load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
187189
if tmpe.ndim == 1:
188190
tmpe = tmpe.reshape([1,-1])
189191
if tmpe.shape[0] < test_size :
@@ -377,7 +379,7 @@ def _get_subdata(self, data, idx = None) :
377379
return new_data
378380

379381
def _load_batch_set (self,
380-
set_name) :
382+
set_name: DPPath) :
381383
self.batch_set = self._load_set(set_name)
382384
self.batch_set, _ = self._shuffle_data(self.batch_set)
383385
self.reset_get_batch()
@@ -386,7 +388,7 @@ def reset_get_batch(self):
386388
self.iterator = 0
387389

388390
def _load_test_set (self,
389-
set_name,
391+
set_name: DPPath,
390392
shuffle_test) :
391393
self.test_set = self._load_set(set_name)
392394
if shuffle_test :
@@ -409,13 +411,15 @@ def _shuffle_data (self,
409411
ret[kk] = data[kk]
410412
return ret, idx
411413

412-
def _load_set(self, set_name) :
414+
def _load_set(self, set_name: DPPath) :
413415
# get nframes
414-
path = os.path.join(set_name, "coord.npy")
416+
if not isinstance(set_name, DPPath):
417+
set_name = DPPath(set_name)
418+
path = set_name / "coord.npy"
415419
if self.data_dict['coord']['high_prec'] :
416-
coord = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION)
420+
coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
417421
else:
418-
coord = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION)
422+
coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
419423
if coord.ndim == 1:
420424
coord = coord.reshape([1,-1])
421425
nframes = coord.shape[0]
@@ -459,12 +463,12 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
459463
ndof = ndof_ * natoms
460464
else:
461465
ndof = ndof_
462-
path = os.path.join(set_name, key+".npy")
463-
if os.path.isfile (path) :
466+
path = set_name / (key+".npy")
467+
if path.is_file() :
464468
if high_prec :
465-
data = np.load(path).astype(GLOBAL_ENER_FLOAT_PRECISION)
469+
data = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
466470
else:
467-
data = np.load(path).astype(GLOBAL_NP_FLOAT_PRECISION)
471+
data = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION)
468472
try: # YWolfeee: deal with data shape error
469473
if atomic :
470474
data = data.reshape([nframes, natoms, -1])
@@ -491,8 +495,8 @@ def _load_data(self, set_name, key, nframes, ndof_, atomic = False, must = True,
491495
return np.float32(0.0), data
492496

493497

494-
def _load_type (self, sys_path) :
495-
atom_type = np.loadtxt (os.path.join(sys_path, "type.raw"), dtype=np.int32, ndmin=1)
498+
def _load_type (self, sys_path: DPPath) :
499+
atom_type = (sys_path / "type.raw").load_txt(dtype=np.int32, ndmin=1)
496500
return atom_type
497501

498502
def _make_idx_map(self, atom_type):
@@ -501,17 +505,16 @@ def _make_idx_map(self, atom_type):
501505
idx_map = np.lexsort ((idx, atom_type))
502506
return idx_map
503507

504-
def _load_type_map(self, sys_path) :
505-
fname = os.path.join(sys_path, 'type_map.raw')
506-
if os.path.isfile(fname) :
507-
with open(os.path.join(sys_path, 'type_map.raw')) as fp:
508-
return fp.read().split()
508+
def _load_type_map(self, sys_path: DPPath) :
509+
fname = sys_path / 'type_map.raw'
510+
if fname.is_file() :
511+
return fname.load_txt(dtype=str).tolist()
509512
else :
510513
return None
511514

512-
def _check_pbc(self, sys_path):
515+
def _check_pbc(self, sys_path: DPPath):
513516
pbc = True
514-
if os.path.isfile(os.path.join(sys_path, 'nopbc')) :
517+
if (sys_path / 'nopbc').is_file() :
515518
pbc = False
516519
return pbc
517520

0 commit comments

Comments
 (0)