Skip to content

Commit efdbb94

Browse files
authored
Merge pull request #77 from njzjz/func_register
add experimental function register
2 parents 21b8b28 + 843e5ee commit efdbb94

File tree

3 files changed

+96
-66
lines changed

3 files changed

+96
-66
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,26 +131,26 @@ Available properties are (nframe: number of frames in the system, natoms: total
131131
## Dump data
132132
The data stored in `System` or `LabeledSystem` can be dumped in 'lammps/lmp' or 'vasp/poscar' format, for example:
133133
```python
134-
d_outcar.to_lammps_lmp('conf.lmp', frame_idx=0)
134+
d_outcar.to('lammps/lmp', 'conf.lmp', frame_idx=0)
135135
```
136136
The first frames of `d_outcar` will be dumped to 'conf.lmp'
137137
```python
138-
d_outcar.to_vasp_poscar('POSCAR', frame_idx=-1)
138+
d_outcar.to('vasp/poscar', 'POSCAR', frame_idx=-1)
139139
```
140140
The last frames of `d_outcar` will be dumped to 'POSCAR'.
141141

142142

143143
The data stored in `LabeledSystem` can be dumped to deepmd-kit raw format, for example
144144
```python
145-
d_outcar.to_deepmd_raw('dpmd_raw')
145+
d_outcar.to('deepmd/raw', 'dpmd_raw')
146146
```
147147
Or a simpler command:
148148
```python
149-
dpdata.LabeledSystem('OUTCAR').to_deepmd_raw('dpmd_raw')
149+
dpdata.LabeledSystem('OUTCAR').to('deepmd/raw', 'dpmd_raw')
150150
```
151151
Frame selection can be implemented by
152152
```python
153-
dpdata.LabeledSystem('OUTCAR').sub_system([0,-1]).to_deepmd_raw('dpmd_raw')
153+
dpdata.LabeledSystem('OUTCAR').sub_system([0,-1]).to('deepmd/raw', 'dpmd_raw')
154154
```
155155
by which only the first and last frames are dumped to `dpmd_raw`.
156156

dpdata/system.py

Lines changed: 82 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#%%
22
import os
33
import glob
4+
import inspect
45
import numpy as np
56
import dpdata.lammps.lmp
67
import dpdata.lammps.dump
@@ -23,6 +24,18 @@
2324
from dpdata.periodic_table import Element
2425
from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems
2526

27+
28+
class Register:
29+
def __init__(self):
30+
self.funcs = {}
31+
32+
def register_funcs(self, fmt):
33+
def decorator(func):
34+
self.funcs[fmt] = func
35+
return func
36+
return decorator
37+
38+
2639
class System (MSONable) :
2740
'''
2841
The data System
@@ -97,30 +110,37 @@ def __init__ (self,
97110
return
98111
if file_name is None :
99112
return
100-
if fmt == 'auto':
101-
fmt = os.path.basename(file_name).split('.')[-1]
102-
if fmt == 'lmp' or fmt == 'lammps/lmp' :
103-
self.from_lammps_lmp(file_name, type_map = type_map)
104-
elif fmt == 'dump' or fmt == 'lammps/dump' :
105-
self.from_lammps_dump(file_name, type_map = type_map, begin = begin, step = step)
106-
elif fmt.lower() == 'poscar' or fmt.lower() == 'contcar' or fmt.lower() == 'vasp/poscar' or fmt.lower() == 'vasp/contcar':
107-
self.from_vasp_poscar(file_name)
108-
elif fmt == 'deepmd' or fmt == 'deepmd/raw':
109-
self.from_deepmd_raw(file_name, type_map = type_map)
110-
elif fmt == 'deepmd/npy':
111-
self.from_deepmd_comp(file_name, type_map = type_map)
112-
elif fmt == 'qe/cp/traj':
113-
self.from_qe_cp_traj(file_name, begin = begin, step = step)
114-
elif fmt.lower() == 'siesta/output':
115-
self.from_siesta_output(file_name)
116-
elif fmt.lower() == 'siesta/aimd_output':
117-
self.from_siesta_aiMD_output(file_name)
118-
else :
119-
raise RuntimeError('unknow data format ' + fmt)
113+
self.from_fmt(file_name, fmt, type_map=type_map, begin= begin, step=step)
120114

121115
if type_map is not None:
122116
self.apply_type_map(type_map)
123117

118+
register_from_funcs = Register()
119+
register_to_funcs = Register()
120+
121+
def from_fmt(self, file_name, fmt='auto', **kwargs):
122+
fmt = fmt.lower()
123+
if fmt == 'auto':
124+
fmt = os.path.basename(file_name).split('.')[-1].lower()
125+
from_funcs = self.register_from_funcs.funcs
126+
if fmt in from_funcs:
127+
func = from_funcs[fmt]
128+
args = inspect.getfullargspec(func).args
129+
kwargs = {kk: kwargs[kk] for kk in kwargs if kk in args}
130+
func(self, file_name, **kwargs)
131+
else :
132+
raise RuntimeError('unknow data format ' + fmt)
133+
134+
def to(self, fmt, *args, **kwargs):
135+
fmt = fmt.lower()
136+
to_funcs = self.register_to_funcs.funcs
137+
if fmt in to_funcs:
138+
func = to_funcs[fmt]
139+
func_args = inspect.getfullargspec(func).args
140+
kwargs = {kk: kwargs[kk] for kk in kwargs if kk in func_args}
141+
func(self, *args, **kwargs)
142+
else :
143+
raise RuntimeError('unknow data format %s. Accepted format:' % (fmt, " ".join(to_funcs)))
124144

125145
def __repr__(self):
126146
return self.__str__()
@@ -207,7 +227,7 @@ def map_atom_types(self,type_map=None):
207227

208228
return new_atom_types
209229

210-
230+
@register_to_funcs.register_funcs("list")
211231
def to_list(self):
212232
"""
213233
convert system to list, usefull for data collection
@@ -414,13 +434,15 @@ def apply_pbc(self) :
414434
ncoord = ncoord % 1
415435
self.data['coords'] = np.matmul(ncoord, self.data['cells'])
416436

417-
437+
@register_from_funcs.register_funcs("lmp")
438+
@register_from_funcs.register_funcs("lammps/lmp")
418439
def from_lammps_lmp (self, file_name, type_map = None) :
419440
with open(file_name) as fp:
420441
lines = [line.rstrip('\n') for line in fp]
421442
self.data = dpdata.lammps.lmp.to_system_data(lines, type_map)
422443
self._shift_orig_zero()
423444

445+
@register_to_funcs.register_funcs("pymatgen/structure")
424446
def to_pymatgen_structure(self):
425447
'''
426448
convert System to Pymatgen Structure obj
@@ -440,6 +462,7 @@ def to_pymatgen_structure(self):
440462
structures.append(structure)
441463
return structures
442464

465+
@register_to_funcs.register_funcs("ase/structure")
443466
def to_ase_structure(self):
444467
'''
445468
convert System to ASE Atom obj
@@ -459,6 +482,7 @@ def to_ase_structure(self):
459482
structures.append(structure)
460483
return structures
461484

485+
@register_to_funcs.register_funcs("lammps/lmp")
462486
def to_lammps_lmp(self, file_name, frame_idx = 0) :
463487
"""
464488
Dump the system in lammps data format
@@ -475,7 +499,8 @@ def to_lammps_lmp(self, file_name, frame_idx = 0) :
475499
with open(file_name, 'w') as fp:
476500
fp.write(w_str)
477501

478-
502+
@register_from_funcs.register_funcs('dump')
503+
@register_from_funcs.register_funcs('lammps/dump')
479504
def from_lammps_dump (self,
480505
file_name,
481506
type_map = None,
@@ -485,13 +510,17 @@ def from_lammps_dump (self,
485510
self.data = dpdata.lammps.dump.system_data(lines, type_map)
486511
self._shift_orig_zero()
487512

488-
513+
@register_from_funcs.register_funcs('poscar')
514+
@register_from_funcs.register_funcs('contcar')
515+
@register_from_funcs.register_funcs('vasp/poscar')
516+
@register_from_funcs.register_funcs('vasp/contcar')
489517
def from_vasp_poscar(self, file_name) :
490518
with open(file_name) as fp:
491519
lines = [line.rstrip('\n') for line in fp]
492520
self.data = dpdata.vasp.poscar.to_system_data(lines)
493521
self.rot_lower_triangular()
494522

523+
@register_to_funcs.register_funcs("vasp/string")
495524
def to_vasp_string(self, frame_idx=0):
496525
"""
497526
Dump the system in vasp POSCAR format string
@@ -505,6 +534,7 @@ def to_vasp_string(self, frame_idx=0):
505534
w_str = dpdata.vasp.poscar.from_system_data(self.data, frame_idx)
506535
return w_str
507536

537+
@register_to_funcs.register_funcs("vasp/poscar")
508538
def to_vasp_poscar(self, file_name, frame_idx = 0) :
509539
"""
510540
Dump the system in vasp POSCAR format
@@ -520,7 +550,7 @@ def to_vasp_poscar(self, file_name, frame_idx = 0) :
520550
with open(file_name, 'w') as fp:
521551
fp.write(w_str)
522552

523-
553+
@register_from_funcs.register_funcs('qe/cp/traj')
524554
def from_qe_cp_traj(self,
525555
prefix,
526556
begin = 0,
@@ -532,15 +562,18 @@ def from_qe_cp_traj(self,
532562
)
533563
self.rot_lower_triangular()
534564

535-
565+
@register_from_funcs.register_funcs('deepmd/npy')
536566
def from_deepmd_comp(self, folder, type_map = None) :
537567
self.data = dpdata.deepmd.comp.to_system_data(folder, type_map = type_map, labels = False)
538568

569+
@register_from_funcs.register_funcs('deepmd')
570+
@register_from_funcs.register_funcs('deepmd/raw')
539571
def from_deepmd_raw(self, folder, type_map = None) :
540572
tmp_data = dpdata.deepmd.raw.to_system_data(folder, type_map = type_map, labels = False)
541573
if tmp_data is not None :
542574
self.data = tmp_data
543575

576+
@register_to_funcs.register_funcs("deepmd/npy")
544577
def to_deepmd_npy(self, folder, set_size = 5000, prec=np.float32) :
545578
"""
546579
Dump the system in deepmd compressed format (numpy binary) to `folder`.
@@ -563,12 +596,14 @@ def to_deepmd_npy(self, folder, set_size = 5000, prec=np.float32) :
563596
set_size = set_size,
564597
comp_prec = prec)
565598

599+
@register_to_funcs.register_funcs("deepmd/raw")
566600
def to_deepmd_raw(self, folder) :
567601
"""
568602
Dump the system in deepmd raw format to `folder`
569603
"""
570604
dpdata.deepmd.raw.dump(folder, self.data)
571605

606+
@register_from_funcs.register_funcs('siesta/output')
572607
def from_siesta_output(self, fname):
573608
self.data['atom_names'], \
574609
self.data['atom_numbs'], \
@@ -579,6 +614,7 @@ def from_siesta_output(self, fname):
579614
= dpdata.siesta.output.obtain_frame(fname)
580615
# self.rot_lower_triangular()
581616

617+
@register_from_funcs.register_funcs('siesta/aimd_output')
582618
def from_siesta_aiMD_output(self, fname):
583619
self.data['atom_names'], \
584620
self.data['atom_numbs'], \
@@ -836,38 +872,11 @@ def __init__ (self,
836872
return
837873
if file_name is None :
838874
return
839-
if fmt == 'auto':
840-
fmt = os.path.basename(file_name).split('.')[-1]
841-
if fmt == 'xml' or fmt == 'XML' or fmt == 'vasp/xml' :
842-
self.from_vasp_xml(file_name, begin = begin, step = step)
843-
elif fmt == 'outcar' or fmt == 'OUTCAR' or fmt == 'vasp/outcar' :
844-
self.from_vasp_outcar(file_name, begin = begin, step = step)
845-
elif fmt == 'deepmd' or fmt == 'deepmd/raw':
846-
self.from_deepmd_raw(file_name, type_map = type_map)
847-
elif fmt == 'deepmd/npy':
848-
self.from_deepmd_comp(file_name, type_map = type_map)
849-
elif fmt == 'qe/cp/traj':
850-
self.from_qe_cp_traj(file_name, begin = begin, step = step)
851-
elif fmt == 'qe/pw/scf':
852-
self.from_qe_pw_scf(file_name)
853-
elif fmt.lower() == 'siesta/output':
854-
self.from_siesta_output(file_name)
855-
elif fmt.lower() == 'siesta/aimd_output':
856-
self.from_siesta_aiMD_output(file_name)
857-
elif fmt == 'gaussian/log':
858-
self.from_gaussian_log(file_name)
859-
elif fmt == 'gaussian/md':
860-
self.from_gaussian_log(file_name, md=True)
861-
elif fmt == 'cp2k/output':
862-
self.from_cp2k_output(file_name)
863-
elif fmt == 'cp2k/aimd_output':
864-
self.from_cp2k_aimd_output(file_dir=file_name)
865-
else :
866-
raise RuntimeError('unknow data format ' + fmt)
867-
875+
self.from_fmt(file_name, fmt, type_map=type_map, begin= begin, step=step)
868876
if type_map is not None:
869877
self.apply_type_map(type_map)
870878

879+
register_from_funcs = Register()
871880

872881
def __repr__(self):
873882
return self.__str__()
@@ -905,14 +914,16 @@ def has_virial(self) :
905914
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
906915
return ('virials' in self.data)
907916

908-
917+
@register_from_funcs.register_funcs('cp2k/aimd_output')
909918
def from_cp2k_aimd_output(self, file_dir):
910919
xyz_file=glob.glob("{}/*pos*.xyz".format(file_dir))[0]
911920
log_file=glob.glob("{}/*.log".format(file_dir))[0]
912921
for info_dict in Cp2kSystems(log_file, xyz_file):
913922
l = LabeledSystem(data=info_dict)
914923
self.append(l)
915924

925+
@register_from_funcs.register_funcs('xml')
926+
@register_from_funcs.register_funcs('vasp/xml')
916927
def from_vasp_xml(self, file_name, begin = 0, step = 1) :
917928
self.data['atom_names'], \
918929
self.data['atom_types'], \
@@ -937,7 +948,8 @@ def from_vasp_xml(self, file_name, begin = 0, step = 1) :
937948
# rotate the system to lammps convention
938949
self.rot_lower_triangular()
939950

940-
951+
@register_from_funcs.register_funcs('outcar')
952+
@register_from_funcs.register_funcs('vasp/outcar')
941953
def from_vasp_outcar(self, file_name, begin = 0, step = 1) :
942954
# with open(file_name) as fp:
943955
# lines = [line.rstrip('\n') for line in fp]
@@ -979,17 +991,18 @@ def rot_frame_lower_triangular(self, f_idx = 0) :
979991
self.affine_map_fv(trans, f_idx = f_idx)
980992
return trans
981993

982-
994+
@register_from_funcs.register_funcs('deepmd/npy')
983995
def from_deepmd_comp(self, folder, type_map = None) :
984996
self.data = dpdata.deepmd.comp.to_system_data(folder, type_map = type_map, labels = True)
985997

986-
998+
@register_from_funcs.register_funcs('deepmd')
999+
@register_from_funcs.register_funcs('deepmd/raw')
9871000
def from_deepmd_raw(self, folder, type_map = None) :
9881001
tmp_data = dpdata.deepmd.raw.to_system_data(folder, type_map = type_map, labels = True)
9891002
if tmp_data is not None :
9901003
self.data = tmp_data
9911004

992-
1005+
@register_from_funcs.register_funcs('qe/cp/traj')
9931006
def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
9941007
self.data, cs = dpdata.qe.traj.to_system_data(prefix + '.in', prefix, begin = begin, step = step)
9951008
self.data['coords'] \
@@ -1001,6 +1014,7 @@ def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
10011014
assert(cs == es), "the step key between files are not consistent"
10021015
self.rot_lower_triangular()
10031016

1017+
@register_from_funcs.register_funcs('qe/pw/scf')
10041018
def from_qe_pw_scf(self, file_name) :
10051019
self.data['atom_names'], \
10061020
self.data['atom_numbs'], \
@@ -1013,6 +1027,7 @@ def from_qe_pw_scf(self, file_name) :
10131027
= dpdata.qe.scf.get_frame(file_name)
10141028
self.rot_lower_triangular()
10151029

1030+
@register_from_funcs.register_funcs('siesta/output')
10161031
def from_siesta_output(self, file_name) :
10171032
self.data['atom_names'], \
10181033
self.data['atom_numbs'], \
@@ -1025,6 +1040,7 @@ def from_siesta_output(self, file_name) :
10251040
= dpdata.siesta.output.obtain_frame(file_name)
10261041
# self.rot_lower_triangular()
10271042

1043+
@register_from_funcs.register_funcs('siesta/aimd_output')
10281044
def from_siesta_aiMD_output(self, file_name):
10291045
self.data['atom_names'], \
10301046
self.data['atom_numbs'], \
@@ -1036,14 +1052,19 @@ def from_siesta_aiMD_output(self, file_name):
10361052
self.data['virials'] \
10371053
= dpdata.siesta.aiMD_output.get_aiMD_frame(file_name)
10381054

1055+
@register_from_funcs.register_funcs('gaussian/log')
10391056
def from_gaussian_log(self, file_name, md=False):
10401057
try:
10411058
self.data = dpdata.gaussian.log.to_system_data(file_name, md=md)
10421059
except AssertionError:
10431060
self.data['energies'], self.data['forces']= [], []
10441061
self.data['nopbc'] = True
1062+
1063+
@register_from_funcs.register_funcs('gaussian/md')
1064+
def from_gaussian_md(self, file_name):
1065+
self.from_gaussian_log(file_name, md=True)
10451066

1046-
1067+
@register_from_funcs.register_funcs('cp2k/output')
10471068
def from_cp2k_output(self, file_name) :
10481069
self.data['atom_names'], \
10491070
self.data['atom_numbs'], \

tests/test_lammps_lmp_dump.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ def setUp(self):
1414
self.system.from_lammps_lmp('tmp.lmp',
1515
type_map = ['O', 'H'])
1616

17+
class TestToFunc(unittest.TestCase, TestPOSCARoh):
18+
19+
def setUp(self):
20+
tmp_system = dpdata.System(os.path.join('poscars', 'conf.lmp'),
21+
type_map = ['O', 'H'])
22+
tmp_system.to('lammps/lmp', 'tmp.lmp')
23+
self.system = dpdata.System()
24+
self.system.from_fmt('tmp.lmp', fmt='lammps/lmp',
25+
type_map = ['O', 'H'])
1726

1827
if __name__ == '__main__':
1928
unittest.main()

0 commit comments

Comments
 (0)