1
1
import numpy as np
2
2
import re
3
3
4
- def system_info (lines , type_idx_zero = False ) :
4
+ def system_info (lines , type_idx_zero = False ):
5
5
atom_names = []
6
6
atom_numbs = None
7
7
nelm = None
@@ -30,7 +30,7 @@ def system_info (lines, type_idx_zero = False) :
30
30
assert (atom_numbs is not None ), "cannot find ion type info in OUTCAR"
31
31
atom_names = atom_names [:len (atom_numbs )]
32
32
atom_types = []
33
- for idx ,ii in enumerate (atom_numbs ) :
33
+ for idx ,ii in enumerate (atom_numbs ):
34
34
for jj in range (ii ) :
35
35
if type_idx_zero :
36
36
atom_types .append (idx )
@@ -39,18 +39,20 @@ def system_info (lines, type_idx_zero = False) :
39
39
return atom_names , atom_numbs , np .array (atom_types , dtype = int ), nelm
40
40
41
41
42
def get_outcar_block(fp, ml=False):
    """Collect lines from ``fp`` up to and including the next energy line.

    Parameters
    ----------
    fp : iterable of str
        Line source (typically an open OUTCAR file). It is consumed from its
        current position, so repeated calls walk through successive blocks.
    ml : bool, optional
        When truthy, a block ends at the 'free energy ML TOTEN' line;
        otherwise it ends at the 'free energy TOTEN' line.

    Returns
    -------
    list of str
        The consumed lines with trailing newline and space characters
        removed. The energy line is the last element when one was found;
        if the input ends first, whatever was collected is returned (an
        empty list when nothing was left to read).
    """
    # Pick the marker by truthiness instead of indexing a list with int(ml),
    # so values such as None or other truthy flags do not raise.
    # NOTE(review): the markers are matched as exact substrings, whitespace
    # included -- confirm the spacing against a real OUTCAR energy line.
    energy_token = 'free energy ML TOTEN' if ml else 'free energy TOTEN'

    blk = []
    for line in fp:
        if not line:
            # An empty string (not even a newline) is treated as end of input.
            return blk
        blk.append(line.rstrip('\n '))
        if energy_token in line:
            # The energy line closes the current block.
            break
    return blk
51
53
52
54
# we assume that the force is printed ...
53
- def get_frames (fname , begin = 0 , step = 1 ) :
55
+ def get_frames (fname , begin = 0 , step = 1 , ml = False ) :
54
56
fp = open (fname )
55
57
blk = get_outcar_block (fp )
56
58
@@ -66,7 +68,7 @@ def get_frames (fname, begin = 0, step = 1) :
66
68
cc = 0
67
69
while len (blk ) > 0 :
68
70
if cc >= begin and (cc - begin ) % step == 0 :
69
- coord , cell , energy , force , virial , is_converge = analyze_block (blk , ntot , nelm )
71
+ coord , cell , energy , force , virial , is_converge = analyze_block (blk , ntot , nelm , ml )
70
72
if is_converge :
71
73
if len (coord ) == 0 :
72
74
break
@@ -76,9 +78,10 @@ def get_frames (fname, begin = 0, step = 1) :
76
78
all_forces .append (force )
77
79
if virial is not None :
78
80
all_virials .append (virial )
79
- blk = get_outcar_block (fp )
81
+
82
+ blk = get_outcar_block (fp , ml )
80
83
cc += 1
81
-
84
+
82
85
if len (all_virials ) == 0 :
83
86
all_virials = None
84
87
else :
@@ -87,36 +90,39 @@ def get_frames (fname, begin = 0, step = 1) :
87
90
return atom_names , atom_numbs , atom_types , np .array (all_cells ), np .array (all_coords ), np .array (all_energies ), np .array (all_forces ), all_virials
88
91
89
92
90
- def analyze_block (lines , ntot , nelm ) :
93
+ def analyze_block (lines , ntot , nelm , ml = False ) :
91
94
coord = []
92
95
cell = []
93
96
energy = None
94
97
force = []
95
98
virial = None
96
99
is_converge = True
97
100
sc_index = 0
98
- for idx ,ii in enumerate (lines ) :
99
- if 'Iteration' in ii :
101
+ #select different searching tokens based on the ml label
102
+ energy_token = ['free energy TOTEN' , 'free energy ML TOTEN' ]
103
+ energy_index = [4 , 5 ]
104
+ viral_token = ['FORCE on cell =-STRESS in cart. coord. units' , 'ML FORCE' ]
105
+ viral_index = [14 , 4 ]
106
+ cell_token = ['VOLUME and BASIS' , 'ML FORCE' ]
107
+ cell_index = [5 , 12 ]
108
+ ml_index = int (ml )
109
+ for idx ,ii in enumerate (lines ):
110
+ #if set ml == True, is_converged will always be True
111
+ if ('Iteration' in ii ) and (not ml ):
100
112
sc_index = int (ii .split ()[3 ][:- 1 ])
101
113
if sc_index >= nelm :
102
114
is_converge = False
103
- elif 'free energy TOTEN' in ii :
104
- energy = float (ii .split ()[4 ])
115
+ elif energy_token [ ml_index ] in ii :
116
+ energy = float (ii .split ()[energy_index [ ml_index ] ])
105
117
assert ((force is not None ) and len (coord ) > 0 and len (cell ) > 0 )
106
- # all_coords.append(coord)
107
- # all_cells.append(cell)
108
- # all_energies.append(energy)
109
- # all_forces.append(force)
110
- # if virial is not None :
111
- # all_virials.append(virial)
112
118
return coord , cell , energy , force , virial , is_converge
113
- elif 'VOLUME and BASIS' in ii :
119
+ elif cell_token [ ml_index ] in ii :
114
120
for dd in range (3 ) :
115
- tmp_l = lines [idx + 5 + dd ]
121
+ tmp_l = lines [idx + cell_index [ ml_index ] + dd ]
116
122
cell .append ([float (ss )
117
123
for ss in tmp_l .replace ('-' ,' -' ).split ()[0 :3 ]])
118
- elif 'in kB' in ii :
119
- tmp_v = [float (ss ) for ss in ii .split ()[2 :8 ]]
124
+ elif viral_token [ ml_index ] in ii :
125
+ tmp_v = [float (ss ) for ss in lines [ idx + viral_index [ ml_index ]] .split ()[2 :8 ]]
120
126
virial = np .zeros ([3 ,3 ])
121
127
virial [0 ][0 ] = tmp_v [0 ]
122
128
virial [1 ][1 ] = tmp_v [1 ]
@@ -127,9 +133,7 @@ def analyze_block(lines, ntot, nelm) :
127
133
virial [2 ][1 ] = tmp_v [4 ]
128
134
virial [0 ][2 ] = tmp_v [5 ]
129
135
virial [2 ][0 ] = tmp_v [5 ]
130
- elif 'TOTAL-FORCE' in ii and ("ML" not in ii ):
131
- # use the lines with " POSITION TOTAL-FORCE (eV/Angst)"
132
- # exclude the lines with " POSITION TOTAL-FORCE (eV/Angst) (ML)"
136
+ elif 'TOTAL-FORCE' in ii and (("ML" in ii ) == ml ):
133
137
for jj in range (idx + 2 , idx + 2 + ntot ) :
134
138
tmp_l = lines [jj ]
135
139
info = [float (ss ) for ss in tmp_l .split ()]
0 commit comments