1
1
#%%
2
2
import os
3
3
import glob
4
+ import inspect
4
5
import numpy as np
5
6
import dpdata .lammps .lmp
6
7
import dpdata .lammps .dump
23
24
from dpdata .periodic_table import Element
24
25
from dpdata .xyz .quip_gap_xyz import QuipGapxyzSystems
25
26
27
+
28
class Register:
    """Registry that maps a format name to the function handling it.

    An instance is meant to be created in a class body and used through
    :meth:`register_funcs` as a decorator on the methods of that class.
    """

    def __init__(self):
        # lower-cased format name -> registered function
        self.funcs = {}

    def register_funcs(self, fmt):
        """Return a decorator that registers the decorated function under `fmt`.

        The key is lower-cased on registration because the dispatchers in
        this file lower-case the requested format before looking it up; a
        mixed-case key would otherwise never be found.

        The decorated function is returned unchanged, so several of these
        decorators can be stacked to register aliases for one function.

        Parameters
        ----------
        fmt : str
            Name of the format handled by the decorated function.
        """
        key = fmt.lower()

        def decorator(func):
            self.funcs[key] = func
            return func

        return decorator
37
+
38
+
26
39
class System (MSONable ) :
27
40
'''
28
41
The data System
@@ -97,30 +110,37 @@ def __init__ (self,
97
110
return
98
111
if file_name is None :
99
112
return
100
- if fmt == 'auto' :
101
- fmt = os .path .basename (file_name ).split ('.' )[- 1 ]
102
- if fmt == 'lmp' or fmt == 'lammps/lmp' :
103
- self .from_lammps_lmp (file_name , type_map = type_map )
104
- elif fmt == 'dump' or fmt == 'lammps/dump' :
105
- self .from_lammps_dump (file_name , type_map = type_map , begin = begin , step = step )
106
- elif fmt .lower () == 'poscar' or fmt .lower () == 'contcar' or fmt .lower () == 'vasp/poscar' or fmt .lower () == 'vasp/contcar' :
107
- self .from_vasp_poscar (file_name )
108
- elif fmt == 'deepmd' or fmt == 'deepmd/raw' :
109
- self .from_deepmd_raw (file_name , type_map = type_map )
110
- elif fmt == 'deepmd/npy' :
111
- self .from_deepmd_comp (file_name , type_map = type_map )
112
- elif fmt == 'qe/cp/traj' :
113
- self .from_qe_cp_traj (file_name , begin = begin , step = step )
114
- elif fmt .lower () == 'siesta/output' :
115
- self .from_siesta_output (file_name )
116
- elif fmt .lower () == 'siesta/aimd_output' :
117
- self .from_siesta_aiMD_output (file_name )
118
- else :
119
- raise RuntimeError ('unknow data format ' + fmt )
113
+ self .from_fmt (file_name , fmt , type_map = type_map , begin = begin , step = step )
120
114
121
115
if type_map is not None :
122
116
self .apply_type_map (type_map )
123
117
118
+ register_from_funcs = Register ()
119
+ register_to_funcs = Register ()
120
+
121
def from_fmt(self, file_name, fmt='auto', **kwargs):
    """Load `file_name` with the loader registered for `fmt`.

    Parameters
    ----------
    file_name : str
        Path handed to the loader as its first argument after `self`.
    fmt : str
        Format name, matched case-insensitively against the keys of
        `self.register_from_funcs.funcs`. With 'auto' the text after the
        last '.' of the file's base name is used instead.
    **kwargs
        Optional loader arguments. Only those the selected loader declares
        by name (regular or keyword-only parameters) are forwarded; the
        rest are dropped so one call site can serve loaders with different
        signatures.

    Raises
    ------
    RuntimeError
        If no loader is registered for the resolved format.
    """
    fmt = fmt.lower()
    if fmt == 'auto':
        fmt = os.path.basename(file_name).split('.')[-1].lower()
    from_funcs = self.register_from_funcs.funcs
    if fmt not in from_funcs:
        raise RuntimeError('unknow data format ' + fmt)
    func = from_funcs[fmt]
    spec = inspect.getfullargspec(func)
    # `spec.args` alone misses keyword-only parameters, which would make
    # options declared after a bare `*` silently disappear.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    options = {kk: vv for kk, vv in kwargs.items() if kk in accepted}
    func(self, file_name, **options)
133
+
134
def to(self, fmt, *args, **kwargs):
    """Convert or dump this system with the function registered for `fmt`.

    Parameters
    ----------
    fmt : str
        Format name, matched case-insensitively against the keys of
        `self.register_to_funcs.funcs`.
    *args
        Positional arguments forwarded unchanged to the converter.
    **kwargs
        Optional converter arguments. Only those the selected converter
        declares by name (regular or keyword-only parameters) are
        forwarded; the rest are dropped.

    Returns
    -------
    Whatever the selected converter returns.

    Raises
    ------
    RuntimeError
        If no converter is registered for `fmt`; the message lists the
        registered format names.
    """
    fmt = fmt.lower()
    to_funcs = self.register_to_funcs.funcs
    if fmt not in to_funcs:
        # Two placeholders for two values: with a single '%s' the formatting
        # itself raised TypeError and hid the intended RuntimeError.
        raise RuntimeError('unknow data format %s. Accepted format: %s'
                           % (fmt, " ".join(to_funcs)))
    func = to_funcs[fmt]
    spec = inspect.getfullargspec(func)
    # `spec.args` alone misses keyword-only parameters.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    options = {kk: vv for kk, vv in kwargs.items() if kk in accepted}
    # Hand the converter's result back to the caller instead of dropping it.
    return func(self, *args, **options)
124
144
125
145
def __repr__ (self ):
126
146
return self .__str__ ()
@@ -207,7 +227,7 @@ def map_atom_types(self,type_map=None):
207
227
208
228
return new_atom_types
209
229
210
-
230
+ @ register_to_funcs . register_funcs ( "list" )
211
231
def to_list (self ):
212
232
"""
213
233
convert system to list, usefull for data collection
@@ -414,13 +434,15 @@ def apply_pbc(self) :
414
434
ncoord = ncoord % 1
415
435
self .data ['coords' ] = np .matmul (ncoord , self .data ['cells' ])
416
436
417
-
437
+ @register_from_funcs .register_funcs ("lmp" )
438
+ @register_from_funcs .register_funcs ("lammps/lmp" )
418
439
def from_lammps_lmp(self, file_name, type_map=None):
    """Fill `self.data` from the file `file_name`.

    The file is read line by line, trailing newline/space characters are
    stripped from every line, and the lines are handed to
    `dpdata.lammps.lmp.to_system_data` together with `type_map`.
    `self._shift_orig_zero()` is called on the result afterwards.
    """
    stripped = []
    with open(file_name) as handle:
        for raw in handle:
            stripped.append(raw.rstrip('\n '))
    self.data = dpdata.lammps.lmp.to_system_data(stripped, type_map)
    self._shift_orig_zero()
423
444
445
+ @register_to_funcs .register_funcs ("pymatgen/structure" )
424
446
def to_pymatgen_structure (self ):
425
447
'''
426
448
convert System to Pymatgen Structure obj
@@ -440,6 +462,7 @@ def to_pymatgen_structure(self):
440
462
structures .append (structure )
441
463
return structures
442
464
465
+ @register_to_funcs .register_funcs ("ase/structure" )
443
466
def to_ase_structure (self ):
444
467
'''
445
468
convert System to ASE Atom obj
@@ -459,6 +482,7 @@ def to_ase_structure(self):
459
482
structures .append (structure )
460
483
return structures
461
484
485
+ @register_to_funcs .register_funcs ("lammps/lmp" )
462
486
def to_lammps_lmp (self , file_name , frame_idx = 0 ) :
463
487
"""
464
488
Dump the system in lammps data format
@@ -475,7 +499,8 @@ def to_lammps_lmp(self, file_name, frame_idx = 0) :
475
499
with open (file_name , 'w' ) as fp :
476
500
fp .write (w_str )
477
501
478
-
502
+ @register_from_funcs .register_funcs ('dump' )
503
+ @register_from_funcs .register_funcs ('lammps/dump' )
479
504
def from_lammps_dump (self ,
480
505
file_name ,
481
506
type_map = None ,
@@ -485,13 +510,17 @@ def from_lammps_dump (self,
485
510
self .data = dpdata .lammps .dump .system_data (lines , type_map )
486
511
self ._shift_orig_zero ()
487
512
488
-
513
+ @register_from_funcs .register_funcs ('poscar' )
514
+ @register_from_funcs .register_funcs ('contcar' )
515
+ @register_from_funcs .register_funcs ('vasp/poscar' )
516
+ @register_from_funcs .register_funcs ('vasp/contcar' )
489
517
def from_vasp_poscar(self, file_name):
    """Fill `self.data` from the file `file_name`.

    The file is read line by line, trailing newline/space characters are
    stripped from every line, and the lines are handed to
    `dpdata.vasp.poscar.to_system_data`. `self.rot_lower_triangular()` is
    called on the result afterwards.
    """
    stripped = []
    with open(file_name) as handle:
        for raw in handle:
            stripped.append(raw.rstrip('\n '))
    self.data = dpdata.vasp.poscar.to_system_data(stripped)
    self.rot_lower_triangular()
494
522
523
+ @register_to_funcs .register_funcs ("vasp/string" )
495
524
def to_vasp_string (self , frame_idx = 0 ):
496
525
"""
497
526
Dump the system in vasp POSCAR format string
@@ -505,6 +534,7 @@ def to_vasp_string(self, frame_idx=0):
505
534
w_str = dpdata .vasp .poscar .from_system_data (self .data , frame_idx )
506
535
return w_str
507
536
537
+ @register_to_funcs .register_funcs ("vasp/poscar" )
508
538
def to_vasp_poscar (self , file_name , frame_idx = 0 ) :
509
539
"""
510
540
Dump the system in vasp POSCAR format
@@ -520,7 +550,7 @@ def to_vasp_poscar(self, file_name, frame_idx = 0) :
520
550
with open (file_name , 'w' ) as fp :
521
551
fp .write (w_str )
522
552
523
-
553
+ @ register_from_funcs . register_funcs ( 'qe/cp/traj' )
524
554
def from_qe_cp_traj (self ,
525
555
prefix ,
526
556
begin = 0 ,
@@ -532,15 +562,18 @@ def from_qe_cp_traj(self,
532
562
)
533
563
self .rot_lower_triangular ()
534
564
535
-
565
+ @ register_from_funcs . register_funcs ( 'deepmd/npy' )
536
566
def from_deepmd_comp(self, folder, type_map=None):
    """Replace `self.data` with what `dpdata.deepmd.comp.to_system_data`
    returns for `folder`, requested with ``labels=False``.
    """
    loaded = dpdata.deepmd.comp.to_system_data(
        folder,
        type_map=type_map,
        labels=False,
    )
    self.data = loaded
538
568
569
+ @register_from_funcs .register_funcs ('deepmd' )
570
+ @register_from_funcs .register_funcs ('deepmd/raw' )
539
571
def from_deepmd_raw(self, folder, type_map=None):
    """Load `folder` through `dpdata.deepmd.raw.to_system_data`
    (``labels=False``).

    `self.data` is only replaced when the reader returns something other
    than None; otherwise the current data is kept.
    """
    loaded = dpdata.deepmd.raw.to_system_data(
        folder,
        type_map=type_map,
        labels=False,
    )
    if loaded is None:
        return
    self.data = loaded
543
575
576
+ @register_to_funcs .register_funcs ("deepmd/npy" )
544
577
def to_deepmd_npy (self , folder , set_size = 5000 , prec = np .float32 ) :
545
578
"""
546
579
Dump the system in deepmd compressed format (numpy binary) to `folder`.
@@ -563,12 +596,14 @@ def to_deepmd_npy(self, folder, set_size = 5000, prec=np.float32) :
563
596
set_size = set_size ,
564
597
comp_prec = prec )
565
598
599
+ @register_to_funcs .register_funcs ("deepmd/raw" )
566
600
def to_deepmd_raw(self, folder):
    """
    Write the system to `folder` in deepmd raw format
    (delegates to `dpdata.deepmd.raw.dump`).
    """
    writer = dpdata.deepmd.raw.dump
    writer(folder, self.data)
571
605
606
+ @register_from_funcs .register_funcs ('siesta/output' )
572
607
def from_siesta_output (self , fname ):
573
608
self .data ['atom_names' ], \
574
609
self .data ['atom_numbs' ], \
@@ -579,6 +614,7 @@ def from_siesta_output(self, fname):
579
614
= dpdata .siesta .output .obtain_frame (fname )
580
615
# self.rot_lower_triangular()
581
616
617
+ @register_from_funcs .register_funcs ('siesta/aimd_output' )
582
618
def from_siesta_aiMD_output (self , fname ):
583
619
self .data ['atom_names' ], \
584
620
self .data ['atom_numbs' ], \
@@ -836,38 +872,11 @@ def __init__ (self,
836
872
return
837
873
if file_name is None :
838
874
return
839
- if fmt == 'auto' :
840
- fmt = os .path .basename (file_name ).split ('.' )[- 1 ]
841
- if fmt == 'xml' or fmt == 'XML' or fmt == 'vasp/xml' :
842
- self .from_vasp_xml (file_name , begin = begin , step = step )
843
- elif fmt == 'outcar' or fmt == 'OUTCAR' or fmt == 'vasp/outcar' :
844
- self .from_vasp_outcar (file_name , begin = begin , step = step )
845
- elif fmt == 'deepmd' or fmt == 'deepmd/raw' :
846
- self .from_deepmd_raw (file_name , type_map = type_map )
847
- elif fmt == 'deepmd/npy' :
848
- self .from_deepmd_comp (file_name , type_map = type_map )
849
- elif fmt == 'qe/cp/traj' :
850
- self .from_qe_cp_traj (file_name , begin = begin , step = step )
851
- elif fmt == 'qe/pw/scf' :
852
- self .from_qe_pw_scf (file_name )
853
- elif fmt .lower () == 'siesta/output' :
854
- self .from_siesta_output (file_name )
855
- elif fmt .lower () == 'siesta/aimd_output' :
856
- self .from_siesta_aiMD_output (file_name )
857
- elif fmt == 'gaussian/log' :
858
- self .from_gaussian_log (file_name )
859
- elif fmt == 'gaussian/md' :
860
- self .from_gaussian_log (file_name , md = True )
861
- elif fmt == 'cp2k/output' :
862
- self .from_cp2k_output (file_name )
863
- elif fmt == 'cp2k/aimd_output' :
864
- self .from_cp2k_aimd_output (file_dir = file_name )
865
- else :
866
- raise RuntimeError ('unknow data format ' + fmt )
867
-
875
+ self .from_fmt (file_name , fmt , type_map = type_map , begin = begin , step = step )
868
876
if type_map is not None :
869
877
self .apply_type_map (type_map )
870
878
879
+ register_from_funcs = Register ()
871
880
872
881
def __repr__ (self ):
873
882
return self .__str__ ()
@@ -905,14 +914,16 @@ def has_virial(self) :
905
914
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
906
915
return ('virials' in self .data )
907
916
908
-
917
+ @ register_from_funcs . register_funcs ( 'cp2k/aimd_output' )
909
918
def from_cp2k_aimd_output(self, file_dir):
    """Append one LabeledSystem per item yielded by `Cp2kSystems`.

    The first ``*pos*.xyz`` match and the first ``*.log`` match inside
    `file_dir` are used as inputs; an empty match list raises IndexError.
    """
    xyz_matches = glob.glob("{}/*pos*.xyz".format(file_dir))
    xyz_path = xyz_matches[0]
    log_matches = glob.glob("{}/*.log".format(file_dir))
    log_path = log_matches[0]
    for frame_data in Cp2kSystems(log_path, xyz_path):
        self.append(LabeledSystem(data=frame_data))
915
924
925
+ @register_from_funcs .register_funcs ('xml' )
926
+ @register_from_funcs .register_funcs ('vasp/xml' )
916
927
def from_vasp_xml (self , file_name , begin = 0 , step = 1 ) :
917
928
self .data ['atom_names' ], \
918
929
self .data ['atom_types' ], \
@@ -937,7 +948,8 @@ def from_vasp_xml(self, file_name, begin = 0, step = 1) :
937
948
# rotate the system to lammps convention
938
949
self .rot_lower_triangular ()
939
950
940
-
951
+ @register_from_funcs .register_funcs ('outcar' )
952
+ @register_from_funcs .register_funcs ('vasp/outcar' )
941
953
def from_vasp_outcar (self , file_name , begin = 0 , step = 1 ) :
942
954
# with open(file_name) as fp:
943
955
# lines = [line.rstrip('\n') for line in fp]
@@ -979,17 +991,18 @@ def rot_frame_lower_triangular(self, f_idx = 0) :
979
991
self .affine_map_fv (trans , f_idx = f_idx )
980
992
return trans
981
993
982
-
994
+ @ register_from_funcs . register_funcs ( 'deepmd/npy' )
983
995
def from_deepmd_comp(self, folder, type_map=None):
    """Replace `self.data` with what `dpdata.deepmd.comp.to_system_data`
    returns for `folder`, requested with ``labels=True``.
    """
    loaded = dpdata.deepmd.comp.to_system_data(
        folder,
        type_map=type_map,
        labels=True,
    )
    self.data = loaded
985
997
986
-
998
+ @register_from_funcs .register_funcs ('deepmd' )
999
+ @register_from_funcs .register_funcs ('deepmd/raw' )
987
1000
def from_deepmd_raw(self, folder, type_map=None):
    """Load `folder` through `dpdata.deepmd.raw.to_system_data`
    (``labels=True``).

    `self.data` is only replaced when the reader returns something other
    than None; otherwise the current data is kept.
    """
    loaded = dpdata.deepmd.raw.to_system_data(
        folder,
        type_map=type_map,
        labels=True,
    )
    if loaded is None:
        return
    self.data = loaded
991
1004
992
-
1005
+ @ register_from_funcs . register_funcs ( 'qe/cp/traj' )
993
1006
def from_qe_cp_traj (self , prefix , begin = 0 , step = 1 ) :
994
1007
self .data , cs = dpdata .qe .traj .to_system_data (prefix + '.in' , prefix , begin = begin , step = step )
995
1008
self .data ['coords' ] \
@@ -1001,6 +1014,7 @@ def from_qe_cp_traj(self, prefix, begin = 0, step = 1) :
1001
1014
assert (cs == es ), "the step key between files are not consistent"
1002
1015
self .rot_lower_triangular ()
1003
1016
1017
+ @register_from_funcs .register_funcs ('qe/pw/scf' )
1004
1018
def from_qe_pw_scf (self , file_name ) :
1005
1019
self .data ['atom_names' ], \
1006
1020
self .data ['atom_numbs' ], \
@@ -1013,6 +1027,7 @@ def from_qe_pw_scf(self, file_name) :
1013
1027
= dpdata .qe .scf .get_frame (file_name )
1014
1028
self .rot_lower_triangular ()
1015
1029
1030
+ @register_from_funcs .register_funcs ('siesta/output' )
1016
1031
def from_siesta_output (self , file_name ) :
1017
1032
self .data ['atom_names' ], \
1018
1033
self .data ['atom_numbs' ], \
@@ -1025,6 +1040,7 @@ def from_siesta_output(self, file_name) :
1025
1040
= dpdata .siesta .output .obtain_frame (file_name )
1026
1041
# self.rot_lower_triangular()
1027
1042
1043
+ @register_from_funcs .register_funcs ('siesta/aimd_output' )
1028
1044
def from_siesta_aiMD_output (self , file_name ):
1029
1045
self .data ['atom_names' ], \
1030
1046
self .data ['atom_numbs' ], \
@@ -1036,14 +1052,19 @@ def from_siesta_aiMD_output(self, file_name):
1036
1052
self .data ['virials' ] \
1037
1053
= dpdata .siesta .aiMD_output .get_aiMD_frame (file_name )
1038
1054
1055
+ @register_from_funcs .register_funcs ('gaussian/log' )
1039
1056
def from_gaussian_log (self , file_name , md = False ):
1040
1057
try :
1041
1058
self .data = dpdata .gaussian .log .to_system_data (file_name , md = md )
1042
1059
except AssertionError :
1043
1060
self .data ['energies' ], self .data ['forces' ]= [], []
1044
1061
self .data ['nopbc' ] = True
1062
+
1063
+ @register_from_funcs .register_funcs ('gaussian/md' )
1064
def from_gaussian_md(self, file_name):
    """Load `file_name` by calling `self.from_gaussian_log` with ``md=True``."""
    load_log = self.from_gaussian_log
    load_log(file_name, md=True)
1045
1066
1046
-
1067
+ @ register_from_funcs . register_funcs ( 'cp2k/output' )
1047
1068
def from_cp2k_output (self , file_name ) :
1048
1069
self .data ['atom_names' ], \
1049
1070
self .data ['atom_numbs' ], \
0 commit comments