Skip to content

Commit 31292fd

Browse files
authored
refactor and optimize deepmd/hdf5 MultiSystems (#291)
This avoids opening the same file multiple times.
1 parent ed0a6a6 commit 31292fd

File tree

3 files changed

+185
-33
lines changed

3 files changed

+185
-33
lines changed

dpdata/deepmd/hdf5.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Utils for deepmd/hdf5 format."""
2+
from typing import Union
3+
24
import h5py
35
import numpy as np
46

@@ -7,16 +9,16 @@
79

810
__all__ = ['to_system_data', 'dump']
911

10-
def to_system_data(f: h5py.File,
12+
def to_system_data(f: Union[h5py.File, h5py.Group],
1113
folder: str,
1214
type_map: list = None,
1315
labels: bool = True) :
1416
"""Load a HDF5 file.
1517
1618
Parameters
1719
----------
18-
f : h5py.File
19-
HDF5 file object
20+
f : h5py.File or h5py.Group
21+
HDF5 file or group object
2022
folder : str
2123
path in the HDF5 file
2224
type_map : list
@@ -82,7 +84,7 @@ def to_system_data(f: h5py.File,
8284
data['cells'] = np.zeros((nframes, 3, 3))
8385
return data
8486

85-
def dump(f: h5py.File,
87+
def dump(f: Union[h5py.File, h5py.Group],
8688
folder: str,
8789
data: dict,
8890
set_size = 5000,
@@ -92,8 +94,8 @@ def dump(f: h5py.File,
9294
9395
Parameters
9496
----------
95-
f : h5py.File
96-
HDF5 file object
97+
f : h5py.File or h5py.Group
98+
HDF5 file or group object
9799
folder : str
98100
path in the HDF5 file
99101
data : dict

dpdata/plugins/deepmd.py

Lines changed: 149 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Iterator, List, Union
2+
13
import dpdata
24
import dpdata.deepmd.raw
35
import dpdata.deepmd.comp
@@ -69,41 +71,164 @@ class DeePMDHDF5Format(Format):
6971
>>> import dpdata
7072
>>> dpdata.MultiSystems().from_deepmd_npy("data").to_deepmd_hdf5("data.hdf5")
7173
"""
72-
def from_system(self, file_name, type_map=None, **kwargs):
73-
s = file_name.split("#")
74-
name = s[1] if len(s) > 1 else ""
75-
with h5py.File(s[0], 'r') as f:
76-
return dpdata.deepmd.hdf5.to_system_data(f, name, type_map=type_map, labels=False)
74+
def _from_system(self, file_name: Union[str, h5py.Group, h5py.File], type_map: List[str], labels: bool):
    """Read System or LabeledSystem data from an HDF5 source.

    Shared implementation for the labeled and non-labeled loaders; the
    ``labels`` flag is forwarded to ``dpdata.deepmd.hdf5.to_system_data``.

    Parameters
    ----------
    file_name : str or h5py.Group or h5py.File
        Either an already opened HDF5 object, which is read with an empty
        group path, or a string of the form ``"path"`` or ``"path#group"``.
        A string is split on ``#``: the first segment is opened read-only
        as the HDF5 file and the second segment (empty when absent) is
        used as the group path. Any further segments are ignored.
    type_map : list[str]
        type map, forwarded unchanged
    labels : bool
        whether labeled data is requested

    Returns
    -------
    dict
        System or LabeledSystem data

    Raises
    ------
    TypeError
        file_name is not str or h5py.Group or h5py.File
    """
    if isinstance(file_name, str):
        segments = file_name.split("#")
        hdf5_path = segments[0]
        group_path = segments[1] if len(segments) > 1 else ""
        # The file is opened only for the duration of this read.
        with h5py.File(hdf5_path, 'r') as h5file:
            return dpdata.deepmd.hdf5.to_system_data(h5file, group_path, type_map=type_map, labels=labels)

    if isinstance(file_name, (h5py.Group, h5py.File)):
        # Caller owns the open object; read from it directly.
        return dpdata.deepmd.hdf5.to_system_data(file_name, "", type_map=type_map, labels=labels)

    raise TypeError("Unsupported file_name")
108+
109+
def from_system(self,
                file_name: Union[str, h5py.Group, h5py.File],
                type_map: List[str]=None,
                **kwargs) -> dict:
    """Load non-labeled System data from an HDF5 source.

    Delegates to ``_from_system`` with ``labels=False``.

    Parameters
    ----------
    file_name : str or h5py.Group or h5py.File
        file name of the HDF5 file or an opened HDF5 object. If it is a
        string, a hashtag separates the path of the HDF5 file from the
        HDF5 group inside it
    type_map : list[str], optional
        type map
    **kwargs
        accepted for interface compatibility; not used by this method

    Returns
    -------
    dict
        System data

    Raises
    ------
    TypeError
        file_name is not str or h5py.Group or h5py.File
    """
    system_data = self._from_system(file_name, type_map=type_map, labels=False)
    return system_data
134+
135+
def from_labeled_system(self,
                        file_name: Union[str, h5py.Group, h5py.File],
                        type_map: List[str]=None,
                        **kwargs) -> dict:
    """Load LabeledSystem data from an HDF5 source.

    Delegates to ``_from_system`` with ``labels=True``.

    Parameters
    ----------
    file_name : str or h5py.Group or h5py.File
        file name of the HDF5 file or an opened HDF5 object. If it is a
        string, a hashtag separates the path of the HDF5 file from the
        HDF5 group inside it
    type_map : list[str], optional
        type map
    **kwargs
        accepted for interface compatibility; not used by this method

    Returns
    -------
    dict
        LabeledSystem data

    Raises
    ------
    TypeError
        file_name is not str or h5py.Group or h5py.File
    """
    labeled_data = self._from_system(file_name, type_map=type_map, labels=True)
    return labeled_data
77160

78-
def from_labeled_system(self, file_name, type_map=None, **kwargs):
79-
s = file_name.split("#")
80-
name = s[1] if len(s) > 1 else ""
81-
with h5py.File(s[0], 'r') as f:
82-
return dpdata.deepmd.hdf5.to_system_data(f, name, type_map=type_map, labels=True)
83-
84161
def to_system(self,
              data : dict,
              file_name: Union[str, h5py.Group, h5py.File],
              set_size : int = 5000,
              comp_prec : np.dtype = np.float64,
              **kwargs):
    """Write System data to an HDF5 file or an opened HDF5 object.

    Parameters
    ----------
    data : dict
        data dict
    file_name : str or h5py.Group or h5py.File
        file name of the HDF5 file or an opened HDF5 object. If it is a
        string, it is split on ``#``: the first segment is the path of
        the HDF5 file and the second segment (empty when absent) is the
        group path inside it
    set_size : int, default=5000
        set size
    comp_prec : np.dtype
        data precision
    **kwargs
        accepted for interface compatibility; not used by this method

    Raises
    ------
    TypeError
        file_name is not str or h5py.Group or h5py.File
    """
    if not isinstance(file_name, (str, h5py.Group, h5py.File)):
        raise TypeError("Unsupported file_name")

    if isinstance(file_name, str):
        segments = file_name.split("#")
        group_path = segments[1] if len(segments) > 1 else ""
        # NOTE: mode 'w' creates the file and truncates an existing one,
        # even when a group path is given after the hashtag.
        with h5py.File(segments[0], 'w') as h5file:
            dpdata.deepmd.hdf5.dump(h5file, group_path, data, set_size = set_size, comp_prec = comp_prec)
    else:
        # Already opened HDF5 object: dump straight into it.
        dpdata.deepmd.hdf5.dump(file_name, "", data, set_size = set_size, comp_prec = comp_prec)
95190

96191
def from_multi_systems(self,
                       directory: str,
                       **kwargs) -> Iterator[h5py.Group]:
    """Generate HDF5 groups from a HDF5 file, which will be
    passed to `from_system`.

    This is a generator: the HDF5 file is opened read-only once and stays
    open while the generator is being iterated, so each yielded object
    should be consumed before advancing to the next one. The file is
    closed when the generator finishes or is closed.

    Parameters
    ----------
    directory : str
        HDF5 file name
    **kwargs
        accepted for interface compatibility; not used by this method

    Yields
    ------
    h5py.Group
        one object per top-level key of the HDF5 file (expected to be
        the group of a single system)
    """
    with h5py.File(directory, 'r') as f:
        for key in f:
            yield f[key]
101210

102211
def to_multi_systems(self,
                     formulas: List[str],
                     directory: str,
                     **kwargs) -> Iterator[h5py.Group]:
    """Generate HDF5 groups, which will be passed to `to_system`.

    This is a generator: the HDF5 file is opened once in ``'w'`` mode
    (an existing file is truncated) and stays open while the generator
    is being iterated, so each yielded group should be written before
    advancing to the next one. The file is closed when the generator
    finishes or is closed.

    Parameters
    ----------
    formulas : list[str]
        formulas of MultiSystems
    directory : str
        HDF5 file name
    **kwargs
        accepted for interface compatibility; not used by this method

    Yields
    ------
    h5py.Group
        a newly created HDF5 group with the name of the formula
    """
    with h5py.File(directory, 'w') as f:
        for formula in formulas:
            yield f.create_group(formula)
107232

108233

109234
@Driver.register("dp")

tests/test_deepmd_hdf5.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import numpy as np
33
import unittest
44
from context import dpdata
5-
from comp_sys import CompLabeledSys, CompSys, IsPBC
5+
from comp_sys import CompLabeledSys, CompSys, IsNoPBC, IsPBC, MultiSystems
66

7-
class TestDeepmdLoadDumpComp(unittest.TestCase, CompLabeledSys, IsPBC):
7+
class TestDeepmdLoadDumpHDF5(unittest.TestCase, CompLabeledSys, IsPBC):
88
def setUp (self) :
99
self.system_1 = dpdata.LabeledSystem('poscars/OUTCAR.h2o.md',
1010
fmt = 'vasp/outcar')
@@ -25,7 +25,7 @@ def tearDown(self) :
2525
os.remove('tmp.deepmd.hdf5')
2626

2727

28-
class TestDeepmdCompNoLabels(unittest.TestCase, CompSys, IsPBC) :
28+
class TestDeepmdHDF5NoLabels(unittest.TestCase, CompSys, IsPBC) :
2929
def setUp (self) :
3030
self.system_1 = dpdata.System('poscars/POSCAR.h2o.md',
3131
fmt = 'vasp/poscar')
@@ -43,3 +43,28 @@ def setUp (self) :
4343
def tearDown(self) :
4444
if os.path.exists('tmp.deepmd.hdf5'):
4545
os.remove('tmp.deepmd.hdf5')
46+
47+
48+
class TestHDF5Multi(unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC):
    """Round-trip a MultiSystems through the deepmd/hdf5 format.

    Three methane Gaussian logs are written to ``tmp.deepmd.hdf5`` and
    read back; the reloaded ``C1H3`` system is compared with the system
    loaded from ``methane_sub.gaussianlog``.
    """

    def setUp (self):
        # Comparison precisions (presumably read by the comp_sys mixins).
        self.places = self.e_places = self.f_places = self.v_places = 6

        methane = dpdata.LabeledSystem('gaussian/methane.gaussianlog', fmt='gaussian/log')
        methane_reordered = dpdata.LabeledSystem('gaussian/methane_reordered.gaussianlog', fmt='gaussian/log')
        methane_sub = dpdata.LabeledSystem('gaussian/methane_sub.gaussianlog', fmt='gaussian/log')

        # Dump all three systems into one HDF5 file, then load it back.
        written = dpdata.MultiSystems(methane, methane_reordered, methane_sub)
        written.to_deepmd_hdf5("tmp.deepmd.hdf5")
        self.systems = dpdata.MultiSystems().from_deepmd_hdf5("tmp.deepmd.hdf5")

        # Expected contents of the reloaded MultiSystems.
        self.system_names = ['C1H4', 'C1H3']
        self.system_sizes = {'C1H4':2, 'C1H3':1}
        self.atom_names = ['C', 'H']

        # Pair compared by CompLabeledSys: reloaded vs. original.
        self.system_1 = self.systems['C1H3']
        self.system_2 = methane_sub

    def tearDown(self) :
        # Remove the temporary HDF5 file if setUp got far enough to create it.
        if not os.path.exists('tmp.deepmd.hdf5'):
            return
        os.remove('tmp.deepmd.hdf5')

0 commit comments

Comments
 (0)