Skip to content

Commit 1766182

Browse files
committed
add function register
1 parent 21b8b28 commit 1766182

File tree

1 file changed

+62
-60
lines changed

1 file changed

+62
-60
lines changed

dpdata/system.py

Lines changed: 62 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#%%
22
import os
33
import glob
4+
import inspect
45
import numpy as np
56
import dpdata.lammps.lmp
67
import dpdata.lammps.dump
@@ -23,6 +24,18 @@
2324
from dpdata.periodic_table import Element
2425
from dpdata.xyz.quip_gap_xyz import QuipGapxyzSystems
2526

27+
28+
class Register:
29+
def __init__(self):
30+
self.funcs = {}
31+
32+
def register_funcs(self, fmt):
33+
def decorator(func):
34+
self.funcs[fmt] = func
35+
return func
36+
return decorator
37+
38+
2639
class System (MSONable) :
2740
'''
2841
The data System
@@ -97,30 +110,25 @@ def __init__ (self,
97110
return
98111
if file_name is None :
99112
return
100-
if fmt == 'auto':
101-
fmt = os.path.basename(file_name).split('.')[-1]
102-
if fmt == 'lmp' or fmt == 'lammps/lmp' :
103-
self.from_lammps_lmp(file_name, type_map = type_map)
104-
elif fmt == 'dump' or fmt == 'lammps/dump' :
105-
self.from_lammps_dump(file_name, type_map = type_map, begin = begin, step = step)
106-
elif fmt.lower() == 'poscar' or fmt.lower() == 'contcar' or fmt.lower() == 'vasp/poscar' or fmt.lower() == 'vasp/contcar':
107-
self.from_vasp_poscar(file_name)
108-
elif fmt == 'deepmd' or fmt == 'deepmd/raw':
109-
self.from_deepmd_raw(file_name, type_map = type_map)
110-
elif fmt == 'deepmd/npy':
111-
self.from_deepmd_comp(file_name, type_map = type_map)
112-
elif fmt == 'qe/cp/traj':
113-
self.from_qe_cp_traj(file_name, begin = begin, step = step)
114-
elif fmt.lower() == 'siesta/output':
115-
self.from_siesta_output(file_name)
116-
elif fmt.lower() == 'siesta/aimd_output':
117-
self.from_siesta_aiMD_output(file_name)
118-
else :
119-
raise RuntimeError('unknow data format ' + fmt)
113+
self.from_fmt(file_name, fmt, type_map=type_map, begin= begin, step=step)
120114

121115
if type_map is not None:
122116
self.apply_type_map(type_map)
123117

118+
register_from_funcs = Register()
119+
120+
def from_fmt(self, file_name, fmt, **kwargs):
121+
fmt = fmt.lower()
122+
if fmt == 'auto':
123+
fmt = os.path.basename(file_name).split('.')[-1].lower()
124+
from_funcs = self.register_from_funcs.funcs
125+
if fmt in from_funcs:
126+
func = from_funcs[fmt]
127+
args = inspect.getfullargspec(func).args
128+
kwargs = {kk: kwargs[kk] for kk in kwargs if kk in args}
129+
func(self, file_name, **kwargs)
130+
else :
131+
raise RuntimeError('unknow data format ' + fmt)
124132

125133
def __repr__(self):
126134
return self.__str__()
@@ -414,7 +422,8 @@ def apply_pbc(self) :
414422
ncoord = ncoord % 1
415423
self.data['coords'] = np.matmul(ncoord, self.data['cells'])
416424

417-
425+
@register_from_funcs.register_funcs("lmp")
426+
@register_from_funcs.register_funcs("lammps/lmp")
418427
def from_lammps_lmp (self, file_name, type_map = None) :
419428
with open(file_name) as fp:
420429
lines = [line.rstrip('\n') for line in fp]
@@ -475,7 +484,8 @@ def to_lammps_lmp(self, file_name, frame_idx = 0) :
475484
with open(file_name, 'w') as fp:
476485
fp.write(w_str)
477486

478-
487+
@register_from_funcs.register_funcs('dump')
488+
@register_from_funcs.register_funcs('lammps/dump')
479489
def from_lammps_dump (self,
480490
file_name,
481491
type_map = None,
@@ -485,7 +495,10 @@ def from_lammps_dump (self,
485495
self.data = dpdata.lammps.dump.system_data(lines, type_map)
486496
self._shift_orig_zero()
487497

488-
498+
@register_from_funcs.register_funcs('poscar')
499+
@register_from_funcs.register_funcs('contcar')
500+
@register_from_funcs.register_funcs('vasp/poscar')
501+
@register_from_funcs.register_funcs('vasp/contcar')
489502
def from_vasp_poscar(self, file_name) :
490503
with open(file_name) as fp:
491504
lines = [line.rstrip('\n') for line in fp]
@@ -520,7 +533,7 @@ def to_vasp_poscar(self, file_name, frame_idx = 0) :
520533
with open(file_name, 'w') as fp:
521534
fp.write(w_str)
522535

523-
536+
@register_from_funcs.register_funcs('qe/cp/traj')
524537
def from_qe_cp_traj(self,
525538
prefix,
526539
begin = 0,
@@ -532,10 +545,12 @@ def from_qe_cp_traj(self,
532545
)
533546
self.rot_lower_triangular()
534547

535-
548+
@register_from_funcs.register_funcs('deepmd/npy')
536549
def from_deepmd_comp(self, folder, type_map = None) :
537550
self.data = dpdata.deepmd.comp.to_system_data(folder, type_map = type_map, labels = False)
538551

552+
@register_from_funcs.register_funcs('deepmd')
553+
@register_from_funcs.register_funcs('deepmd/raw')
539554
def from_deepmd_raw(self, folder, type_map = None) :
540555
tmp_data = dpdata.deepmd.raw.to_system_data(folder, type_map = type_map, labels = False)
541556
if tmp_data is not None :
@@ -569,6 +584,7 @@ def to_deepmd_raw(self, folder) :
569584
"""
570585
dpdata.deepmd.raw.dump(folder, self.data)
571586

587+
@register_from_funcs.register_funcs('siesta/output')
572588
def from_siesta_output(self, fname):
573589
self.data['atom_names'], \
574590
self.data['atom_numbs'], \
@@ -579,6 +595,7 @@ def from_siesta_output(self, fname):
579595
= dpdata.siesta.output.obtain_frame(fname)
580596
# self.rot_lower_triangular()
581597

598+
@register_from_funcs.register_funcs('aimd/output')
582599
def from_siesta_aiMD_output(self, fname):
583600
self.data['atom_names'], \
584601
self.data['atom_numbs'], \
@@ -836,38 +853,11 @@ def __init__ (self,
836853
return
837854
if file_name is None :
838855
return
839-
if fmt == 'auto':
840-
fmt = os.path.basename(file_name).split('.')[-1]
841-
if fmt == 'xml' or fmt == 'XML' or fmt == 'vasp/xml' :
842-
self.from_vasp_xml(file_name, begin = begin, step = step)
843-
elif fmt == 'outcar' or fmt == 'OUTCAR' or fmt == 'vasp/outcar' :
844-
self.from_vasp_outcar(file_name, begin = begin, step = step)
845-
elif fmt == 'deepmd' or fmt == 'deepmd/raw':
846-
self.from_deepmd_raw(file_name, type_map = type_map)
847-
elif fmt == 'deepmd/npy':
848-
self.from_deepmd_comp(file_name, type_map = type_map)
849-
elif fmt == 'qe/cp/traj':
850-
self.from_qe_cp_traj(file_name, begin = begin, step = step)
851-
elif fmt == 'qe/pw/scf':
852-
self.from_qe_pw_scf(file_name)
853-
elif fmt.lower() == 'siesta/output':
854-
self.from_siesta_output(file_name)
855-
elif fmt.lower() == 'siesta/aimd_output':
856-
self.from_siesta_aiMD_output(file_name)
857-
elif fmt == 'gaussian/log':
858-
self.from_gaussian_log(file_name)
859-
elif fmt == 'gaussian/md':
860-
self.from_gaussian_log(file_name, md=True)
861-
elif fmt == 'cp2k/output':
862-
self.from_cp2k_output(file_name)
863-
elif fmt == 'cp2k/aimd_output':
864-
self.from_cp2k_aimd_output(file_dir=file_name)
865-
else :
866-
raise RuntimeError('unknow data format ' + fmt)
867-
856+
self.from_fmt(file_name, fmt, type_map=type_map, begin= begin, step=step)
868857
if type_map is not None:
869858
self.apply_type_map(type_map)
870859

860+
register_from_funcs = Register()
871861

872862
def __repr__(self):
873863
return self.__str__()
@@ -905,14 +895,16 @@ def has_virial(self) :
905895
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
906896
return ('virials' in self.data)
907897

908-
898+
@register_from_funcs.register_funcs('cp2k/aimd_output')
909899
def from_cp2k_aimd_output(self, file_dir):
910900
xyz_file=glob.glob("{}/*pos*.xyz".format(file_dir))[0]
911901
log_file=glob.glob("{}/*.log".format(file_dir))[0]
912902
for info_dict in Cp2kSystems(log_file, xyz_file):
913903
l = LabeledSystem(data=info_dict)
914904
self.append(l)
915905

906+
@register_from_funcs.register_funcs('xml')
907+
@register_from_funcs.register_funcs('vasp/xml')
916908
def from_vasp_xml(self, file_name, begin = 0, step = 1) :
917909
self.data['atom_names'], \
918910
self.data['atom_types'], \
@@ -937,7 +929,8 @@ def from_vasp_xml(self, file_name, begin = 0, step = 1) :
937929
# rotate the system to lammps convention
938930
self.rot_lower_triangular()
939931

940-
932+
@register_from_funcs.register_funcs('outcar')
933+
@register_from_funcs.register_funcs('vasp/outcar')
941934
def from_vasp_outcar(self, file_name, begin = 0, step = 1) :
942935
# with open(file_name) as fp:
943936
# lines = [line.rstrip('\n') for line in fp]
@@ -979,17 +972,18 @@ def rot_frame_lower_triangular(self, f_idx = 0) :
979972
self.affine_map_fv(trans, f_idx = f_idx)
980973
return trans
981974

982-
975+
@register_from_funcs.register_funcs('deepmd/npy')
983976
def from_deepmd_comp(self, folder, type_map = None) :
984977
self.data = dpdata.deepmd.comp.to_system_data(folder, type_map = type_map, labels = True)
985978

986-
979+
@register_from_funcs.register_funcs('deepmd')
980+
@register_from_funcs.register_funcs('deepmd/raw')
987981
def from_deepmd_raw(self, folder, type_map = None) :
988982
tmp_data = dpdata.deepmd.raw.to_system_data(folder, type_map = type_map, labels = True)
989983
if tmp_data is not None :
990984
self.data = tmp_data
991985

992-
986+
@register_from_funcs.register_funcs('qe/cp/traj')
993987
def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
994988
self.data, cs = dpdata.qe.traj.to_system_data(prefix + '.in', prefix, begin = begin, step = step)
995989
self.data['coords'] \
@@ -1001,6 +995,7 @@ def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
1001995
assert(cs == es), "the step key between files are not consistent"
1002996
self.rot_lower_triangular()
1003997

998+
@register_from_funcs.register_funcs('qe/pw/scf')
1004999
def from_qe_pw_scf(self, file_name) :
10051000
self.data['atom_names'], \
10061001
self.data['atom_numbs'], \
@@ -1013,6 +1008,7 @@ def from_qe_pw_scf(self, file_name) :
10131008
= dpdata.qe.scf.get_frame(file_name)
10141009
self.rot_lower_triangular()
10151010

1011+
@register_from_funcs.register_funcs('siesta/output')
10161012
def from_siesta_output(self, file_name) :
10171013
self.data['atom_names'], \
10181014
self.data['atom_numbs'], \
@@ -1025,6 +1021,7 @@ def from_siesta_output(self, file_name) :
10251021
= dpdata.siesta.output.obtain_frame(file_name)
10261022
# self.rot_lower_triangular()
10271023

1024+
@register_from_funcs.register_funcs('siesta/aimd_output')
10281025
def from_siesta_aiMD_output(self, file_name):
10291026
self.data['atom_names'], \
10301027
self.data['atom_numbs'], \
@@ -1036,14 +1033,19 @@ def from_siesta_aiMD_output(self, file_name):
10361033
self.data['virials'] \
10371034
= dpdata.siesta.aiMD_output.get_aiMD_frame(file_name)
10381035

1036+
@register_from_funcs.register_funcs('gaussian/log')
10391037
def from_gaussian_log(self, file_name, md=False):
10401038
try:
10411039
self.data = dpdata.gaussian.log.to_system_data(file_name, md=md)
10421040
except AssertionError:
10431041
self.data['energies'], self.data['forces']= [], []
10441042
self.data['nopbc'] = True
1043+
1044+
@register_from_funcs.register_funcs('gaussian/md')
1045+
def from_gaussian_md(self, file_name):
1046+
self.from_gaussian_log(file_name, md=True)
10451047

1046-
1048+
@register_from_funcs.register_funcs('cp2k/output')
10471049
def from_cp2k_output(self, file_name) :
10481050
self.data['atom_names'], \
10491051
self.data['atom_numbs'], \

0 commit comments

Comments
 (0)