1
1
#%%
2
2
import os
3
3
import glob
4
+ import inspect
4
5
import numpy as np
5
6
import dpdata .lammps .lmp
6
7
import dpdata .lammps .dump
23
24
from dpdata .periodic_table import Element
24
25
from dpdata .xyz .quip_gap_xyz import QuipGapxyzSystems
25
26
27
+
28
class Register:
    """Registry that maps a format key to the function able to read it.

    Attributes
    ----------
    funcs : dict
        Lower-cased format key -> registered function.
    """

    def __init__(self):
        self.funcs = {}

    def register_funcs(self, fmt):
        """Return a decorator that records the decorated function under ``fmt``.

        Parameters
        ----------
        fmt : str
            Format key. It is lower-cased before being stored, because the
            dispatcher lower-cases the requested format before looking it
            up; a key stored with capitals could never be matched.

        Returns
        -------
        callable
            Decorator that registers its argument and returns it unchanged,
            so several registrations can be stacked on one function.
        """
        key = fmt.lower()

        def decorator(func):
            # A later registration under the same key replaces the earlier one.
            self.funcs[key] = func
            return func

        return decorator
37
+
38
+
26
39
class System (MSONable ) :
27
40
'''
28
41
The data System
@@ -97,30 +110,25 @@ def __init__ (self,
97
110
return
98
111
if file_name is None :
99
112
return
100
- if fmt == 'auto' :
101
- fmt = os .path .basename (file_name ).split ('.' )[- 1 ]
102
- if fmt == 'lmp' or fmt == 'lammps/lmp' :
103
- self .from_lammps_lmp (file_name , type_map = type_map )
104
- elif fmt == 'dump' or fmt == 'lammps/dump' :
105
- self .from_lammps_dump (file_name , type_map = type_map , begin = begin , step = step )
106
- elif fmt .lower () == 'poscar' or fmt .lower () == 'contcar' or fmt .lower () == 'vasp/poscar' or fmt .lower () == 'vasp/contcar' :
107
- self .from_vasp_poscar (file_name )
108
- elif fmt == 'deepmd' or fmt == 'deepmd/raw' :
109
- self .from_deepmd_raw (file_name , type_map = type_map )
110
- elif fmt == 'deepmd/npy' :
111
- self .from_deepmd_comp (file_name , type_map = type_map )
112
- elif fmt == 'qe/cp/traj' :
113
- self .from_qe_cp_traj (file_name , begin = begin , step = step )
114
- elif fmt .lower () == 'siesta/output' :
115
- self .from_siesta_output (file_name )
116
- elif fmt .lower () == 'siesta/aimd_output' :
117
- self .from_siesta_aiMD_output (file_name )
118
- else :
119
- raise RuntimeError ('unknow data format ' + fmt )
113
+ self .from_fmt (file_name , fmt , type_map = type_map , begin = begin , step = step )
120
114
121
115
if type_map is not None :
122
116
self .apply_type_map (type_map )
123
117
118
+ register_from_funcs = Register ()
119
+
120
def from_fmt(self, file_name, fmt, **kwargs):
    """Load ``file_name`` with the reader registered for ``fmt``.

    Parameters
    ----------
    file_name : str
        Passed to the reader as its first positional argument after ``self``.
    fmt : str
        Format key, matched case-insensitively. ``'auto'`` uses the text
        after the last ``.`` of the base name of ``file_name`` (the whole
        base name when it contains no dot).
    **kwargs
        Optional reader arguments. Only the ones the selected reader can
        accept are forwarded; the rest are silently dropped.

    Raises
    ------
    RuntimeError
        If no reader is registered under the resolved format key.
    """
    fmt = fmt.lower()
    if fmt == 'auto':
        fmt = os.path.basename(file_name).split('.')[-1].lower()

    from_funcs = self.register_from_funcs.funcs
    if fmt not in from_funcs:
        raise RuntimeError('unknow data format ' + fmt)

    func = from_funcs[fmt]
    spec = inspect.getfullargspec(func)
    if spec.varkw is None:
        # The first two positional parameters are already bound to ``self``
        # and ``file_name`` below, so their names must not be forwarded again.
        # Keyword-only parameters are legitimate targets as well.
        accepted = set(spec.args[2:]) | set(spec.kwonlyargs)
        kwargs = {kk: vv for kk, vv in kwargs.items() if kk in accepted}
    # A reader declaring ``**kwargs`` can take everything, so nothing is dropped.
    func(self, file_name, **kwargs)
124
132
125
133
def __repr__ (self ):
126
134
return self .__str__ ()
@@ -414,7 +422,8 @@ def apply_pbc(self) :
414
422
ncoord = ncoord % 1
415
423
self .data ['coords' ] = np .matmul (ncoord , self .data ['cells' ])
416
424
417
-
425
+ @register_from_funcs .register_funcs ("lmp" )
426
+ @register_from_funcs .register_funcs ("lammps/lmp" )
418
427
def from_lammps_lmp (self , file_name , type_map = None ) :
419
428
with open (file_name ) as fp :
420
429
lines = [line .rstrip ('\n ' ) for line in fp ]
@@ -475,7 +484,8 @@ def to_lammps_lmp(self, file_name, frame_idx = 0) :
475
484
with open (file_name , 'w' ) as fp :
476
485
fp .write (w_str )
477
486
478
-
487
+ @register_from_funcs .register_funcs ('dump' )
488
+ @register_from_funcs .register_funcs ('lammps/dump' )
479
489
def from_lammps_dump (self ,
480
490
file_name ,
481
491
type_map = None ,
@@ -485,7 +495,10 @@ def from_lammps_dump (self,
485
495
self .data = dpdata .lammps .dump .system_data (lines , type_map )
486
496
self ._shift_orig_zero ()
487
497
488
-
498
+ @register_from_funcs .register_funcs ('poscar' )
499
+ @register_from_funcs .register_funcs ('contcar' )
500
+ @register_from_funcs .register_funcs ('vasp/poscar' )
501
+ @register_from_funcs .register_funcs ('vasp/contcar' )
489
502
def from_vasp_poscar (self , file_name ) :
490
503
with open (file_name ) as fp :
491
504
lines = [line .rstrip ('\n ' ) for line in fp ]
@@ -520,7 +533,7 @@ def to_vasp_poscar(self, file_name, frame_idx = 0) :
520
533
with open (file_name , 'w' ) as fp :
521
534
fp .write (w_str )
522
535
523
-
536
+ @ register_from_funcs . register_funcs ( 'qe/cp/traj' )
524
537
def from_qe_cp_traj (self ,
525
538
prefix ,
526
539
begin = 0 ,
@@ -532,10 +545,12 @@ def from_qe_cp_traj(self,
532
545
)
533
546
self .rot_lower_triangular ()
534
547
535
-
548
@register_from_funcs.register_funcs('deepmd/npy')
def from_deepmd_comp(self, folder, type_map=None):
    """Replace ``self.data`` with what ``dpdata.deepmd.comp`` reads from ``folder``.

    ``type_map`` is forwarded unchanged; labels are not requested
    (``labels=False``).
    """
    loaded = dpdata.deepmd.comp.to_system_data(
        folder,
        type_map=type_map,
        labels=False,
    )
    self.data = loaded
538
551
552
+ @register_from_funcs .register_funcs ('deepmd' )
553
+ @register_from_funcs .register_funcs ('deepmd/raw' )
539
554
def from_deepmd_raw (self , folder , type_map = None ) :
540
555
tmp_data = dpdata .deepmd .raw .to_system_data (folder , type_map = type_map , labels = False )
541
556
if tmp_data is not None :
@@ -569,6 +584,7 @@ def to_deepmd_raw(self, folder) :
569
584
"""
570
585
dpdata .deepmd .raw .dump (folder , self .data )
571
586
587
+ @register_from_funcs .register_funcs ('siesta/output' )
572
588
def from_siesta_output (self , fname ):
573
589
self .data ['atom_names' ], \
574
590
self .data ['atom_numbs' ], \
@@ -579,6 +595,7 @@ def from_siesta_output(self, fname):
579
595
= dpdata .siesta .output .obtain_frame (fname )
580
596
# self.rot_lower_triangular()
581
597
598
+ @register_from_funcs .register_funcs ('aimd/output' )
582
599
def from_siesta_aiMD_output (self , fname ):
583
600
self .data ['atom_names' ], \
584
601
self .data ['atom_numbs' ], \
@@ -836,38 +853,11 @@ def __init__ (self,
836
853
return
837
854
if file_name is None :
838
855
return
839
- if fmt == 'auto' :
840
- fmt = os .path .basename (file_name ).split ('.' )[- 1 ]
841
- if fmt == 'xml' or fmt == 'XML' or fmt == 'vasp/xml' :
842
- self .from_vasp_xml (file_name , begin = begin , step = step )
843
- elif fmt == 'outcar' or fmt == 'OUTCAR' or fmt == 'vasp/outcar' :
844
- self .from_vasp_outcar (file_name , begin = begin , step = step )
845
- elif fmt == 'deepmd' or fmt == 'deepmd/raw' :
846
- self .from_deepmd_raw (file_name , type_map = type_map )
847
- elif fmt == 'deepmd/npy' :
848
- self .from_deepmd_comp (file_name , type_map = type_map )
849
- elif fmt == 'qe/cp/traj' :
850
- self .from_qe_cp_traj (file_name , begin = begin , step = step )
851
- elif fmt == 'qe/pw/scf' :
852
- self .from_qe_pw_scf (file_name )
853
- elif fmt .lower () == 'siesta/output' :
854
- self .from_siesta_output (file_name )
855
- elif fmt .lower () == 'siesta/aimd_output' :
856
- self .from_siesta_aiMD_output (file_name )
857
- elif fmt == 'gaussian/log' :
858
- self .from_gaussian_log (file_name )
859
- elif fmt == 'gaussian/md' :
860
- self .from_gaussian_log (file_name , md = True )
861
- elif fmt == 'cp2k/output' :
862
- self .from_cp2k_output (file_name )
863
- elif fmt == 'cp2k/aimd_output' :
864
- self .from_cp2k_aimd_output (file_dir = file_name )
865
- else :
866
- raise RuntimeError ('unknow data format ' + fmt )
867
-
856
+ self .from_fmt (file_name , fmt , type_map = type_map , begin = begin , step = step )
868
857
if type_map is not None :
869
858
self .apply_type_map (type_map )
870
859
860
+ register_from_funcs = Register ()
871
861
872
862
def __repr__ (self ):
873
863
return self .__str__ ()
@@ -905,14 +895,16 @@ def has_virial(self) :
905
895
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
906
896
return ('virials' in self .data )
907
897
908
-
898
@register_from_funcs.register_funcs('cp2k/aimd_output')
def from_cp2k_aimd_output(self, file_dir):
    """Append one labeled frame per entry yielded by ``Cp2kSystems``.

    Parameters
    ----------
    file_dir : str
        Directory searched for a ``*pos*.xyz`` file and a ``*.log`` file.
        The first match of each pattern is used.

    Raises
    ------
    IndexError
        If either pattern matches nothing. The exception type is the one the
        bare ``[0]`` lookup used to raise, so existing handlers keep working;
        it now carries a message naming the missing file.
    """
    xyz_matches = glob.glob("{}/*pos*.xyz".format(file_dir))
    if not xyz_matches:
        raise IndexError("no '*pos*.xyz' file found in {}".format(file_dir))
    log_matches = glob.glob("{}/*.log".format(file_dir))
    if not log_matches:
        raise IndexError("no '*.log' file found in {}".format(file_dir))

    for info_dict in Cp2kSystems(log_matches[0], xyz_matches[0]):
        self.append(LabeledSystem(data=info_dict))
915
905
906
+ @register_from_funcs .register_funcs ('xml' )
907
+ @register_from_funcs .register_funcs ('vasp/xml' )
916
908
def from_vasp_xml (self , file_name , begin = 0 , step = 1 ) :
917
909
self .data ['atom_names' ], \
918
910
self .data ['atom_types' ], \
@@ -937,7 +929,8 @@ def from_vasp_xml(self, file_name, begin = 0, step = 1) :
937
929
# rotate the system to lammps convention
938
930
self .rot_lower_triangular ()
939
931
940
-
932
+ @register_from_funcs .register_funcs ('outcar' )
933
+ @register_from_funcs .register_funcs ('vasp/outcar' )
941
934
def from_vasp_outcar (self , file_name , begin = 0 , step = 1 ) :
942
935
# with open(file_name) as fp:
943
936
# lines = [line.rstrip('\n') for line in fp]
@@ -979,17 +972,18 @@ def rot_frame_lower_triangular(self, f_idx = 0) :
979
972
self .affine_map_fv (trans , f_idx = f_idx )
980
973
return trans
981
974
982
-
975
@register_from_funcs.register_funcs('deepmd/npy')
def from_deepmd_comp(self, folder, type_map=None):
    """Replace ``self.data`` with what ``dpdata.deepmd.comp`` reads from ``folder``.

    ``type_map`` is forwarded unchanged; labels are requested
    (``labels=True``).
    """
    loaded = dpdata.deepmd.comp.to_system_data(
        folder,
        type_map=type_map,
        labels=True,
    )
    self.data = loaded
985
978
986
-
979
@register_from_funcs.register_funcs('deepmd')
@register_from_funcs.register_funcs('deepmd/raw')
def from_deepmd_raw(self, folder, type_map=None):
    """Load labeled data from ``folder`` through ``dpdata.deepmd.raw``.

    ``self.data`` is left untouched when the loader returns ``None``.
    """
    loaded = dpdata.deepmd.raw.to_system_data(
        folder,
        type_map=type_map,
        labels=True,
    )
    if loaded is None:
        return
    self.data = loaded
991
985
992
-
986
+ @ register_from_funcs . register_funcs ( 'qe/cp/traj' )
993
987
def from_qe_cp_traj (self , prefix , begin = 0 , step = 1 ) :
994
988
self .data , cs = dpdata .qe .traj .to_system_data (prefix + '.in' , prefix , begin = begin , step = step )
995
989
self .data ['coords' ] \
@@ -1001,6 +995,7 @@ def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
1001
995
assert (cs == es ), "the step key between files are not consistent"
1002
996
self .rot_lower_triangular ()
1003
997
998
+ @register_from_funcs .register_funcs ('qe/pw/scf' )
1004
999
def from_qe_pw_scf (self , file_name ) :
1005
1000
self .data ['atom_names' ], \
1006
1001
self .data ['atom_numbs' ], \
@@ -1013,6 +1008,7 @@ def from_qe_pw_scf(self, file_name) :
1013
1008
= dpdata .qe .scf .get_frame (file_name )
1014
1009
self .rot_lower_triangular ()
1015
1010
1011
+ @register_from_funcs .register_funcs ('siesta/output' )
1016
1012
def from_siesta_output (self , file_name ) :
1017
1013
self .data ['atom_names' ], \
1018
1014
self .data ['atom_numbs' ], \
@@ -1025,6 +1021,7 @@ def from_siesta_output(self, file_name) :
1025
1021
= dpdata .siesta .output .obtain_frame (file_name )
1026
1022
# self.rot_lower_triangular()
1027
1023
1024
+ @register_from_funcs .register_funcs ('siesta/aimd_output' )
1028
1025
def from_siesta_aiMD_output (self , file_name ):
1029
1026
self .data ['atom_names' ], \
1030
1027
self .data ['atom_numbs' ], \
@@ -1036,14 +1033,19 @@ def from_siesta_aiMD_output(self, file_name):
1036
1033
self .data ['virials' ] \
1037
1034
= dpdata .siesta .aiMD_output .get_aiMD_frame (file_name )
1038
1035
1036
# Reads a Gaussian log through dpdata.gaussian.log.to_system_data, forwarding `md`.
# If that call raises AssertionError, empty 'energies' and 'forces' lists are stored instead.
# NOTE(review): indentation is lost in this view -- confirm whether the 'nopbc'
# assignment belongs to the except branch or runs unconditionally.
+ @register_from_funcs .register_funcs ('gaussian/log' )
1039
1037
def from_gaussian_log (self , file_name , md = False ):
1040
1038
try :
1041
1039
self .data = dpdata .gaussian .log .to_system_data (file_name , md = md )
1042
1040
except AssertionError :
1043
1041
self .data ['energies' ], self .data ['forces' ]= [], []
1044
1042
self .data ['nopbc' ] = True
1043
+
1044
@register_from_funcs.register_funcs('gaussian/md')
def from_gaussian_md(self, file_name):
    """Read ``file_name`` by delegating to ``from_gaussian_log`` with ``md=True``."""
    md_mode = True
    self.from_gaussian_log(file_name, md=md_mode)
1045
1047
1046
-
1048
+ @ register_from_funcs . register_funcs ( 'cp2k/output' )
1047
1049
def from_cp2k_output (self , file_name ) :
1048
1050
self .data ['atom_names' ], \
1049
1051
self .data ['atom_numbs' ], \
0 commit comments