Skip to content

Commit 0a1c18b

Browse files
authored
add ML labeling option for vasp ml-aimd OUTCAR (#282)
* add ML labeling option for vasp ml-aimd OUTCAR
* add test cases for ML labeling OUTCAR
* Update vasp.py
* Update vasp.py
* Update outcar.py
* merge functions of ML labeling and fp labeling
* add and delete necessary comments
* Update outcar.py
1 parent 506f168 commit 0a1c18b

File tree

4 files changed

+3334
-27
lines changed

4 files changed

+3334
-27
lines changed

dpdata/plugins/vasp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class VASPOutcarFormat(Format):
5656
@Format.post("rot_lower_triangular")
5757
def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
5858
data = {}
59+
ml = kwargs.get("ml", False)
5960
data['atom_names'], \
6061
data['atom_numbs'], \
6162
data['atom_types'], \
@@ -64,7 +65,7 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
6465
data['energies'], \
6566
data['forces'], \
6667
tmp_virial, \
67-
= dpdata.vasp.outcar.get_frames(file_name, begin=begin, step=step)
68+
= dpdata.vasp.outcar.get_frames(file_name, begin=begin, step=step, ml=ml)
6869
if tmp_virial is not None:
6970
data['virials'] = tmp_virial
7071
# scale virial to the unit of eV
@@ -107,3 +108,4 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
107108
data['virials'][ii] *= v_pref * vol
108109
data = uniq_atom_names(data)
109110
return data
111+

dpdata/vasp/outcar.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import re
33

4-
def system_info (lines, type_idx_zero = False) :
4+
def system_info(lines, type_idx_zero = False):
55
atom_names = []
66
atom_numbs = None
77
nelm = None
@@ -30,7 +30,7 @@ def system_info (lines, type_idx_zero = False) :
3030
assert(atom_numbs is not None), "cannot find ion type info in OUTCAR"
3131
atom_names = atom_names[:len(atom_numbs)]
3232
atom_types = []
33-
for idx,ii in enumerate(atom_numbs) :
33+
for idx,ii in enumerate(atom_numbs):
3434
for jj in range(ii) :
3535
if type_idx_zero :
3636
atom_types.append(idx)
@@ -39,18 +39,20 @@ def system_info (lines, type_idx_zero = False) :
3939
return atom_names, atom_numbs, np.array(atom_types, dtype = int), nelm
4040

4141

42-
def get_outcar_block(fp) :
42+
def get_outcar_block(fp, ml = False):
4343
blk = []
44+
energy_token = ['free energy TOTEN', 'free energy ML TOTEN']
45+
ml_index = int(ml)
4446
for ii in fp :
4547
if not ii :
4648
return blk
4749
blk.append(ii.rstrip('\n'))
48-
if 'free energy TOTEN' in ii:
50+
if energy_token[ml_index] in ii:
4951
return blk
5052
return blk
5153

5254
# we assume that the force is printed ...
53-
def get_frames (fname, begin = 0, step = 1) :
55+
def get_frames(fname, begin = 0, step = 1, ml = False):
5456
fp = open(fname)
5557
blk = get_outcar_block(fp)
5658

@@ -66,7 +68,7 @@ def get_frames (fname, begin = 0, step = 1) :
6668
cc = 0
6769
while len(blk) > 0 :
6870
if cc >= begin and (cc - begin) % step == 0 :
69-
coord, cell, energy, force, virial, is_converge = analyze_block(blk, ntot, nelm)
71+
coord, cell, energy, force, virial, is_converge = analyze_block(blk, ntot, nelm, ml)
7072
if is_converge :
7173
if len(coord) == 0:
7274
break
@@ -76,9 +78,10 @@ def get_frames (fname, begin = 0, step = 1) :
7678
all_forces.append(force)
7779
if virial is not None :
7880
all_virials.append(virial)
79-
blk = get_outcar_block(fp)
81+
82+
blk = get_outcar_block(fp, ml)
8083
cc += 1
81-
84+
8285
if len(all_virials) == 0 :
8386
all_virials = None
8487
else :
@@ -87,36 +90,39 @@ def get_frames (fname, begin = 0, step = 1) :
8790
return atom_names, atom_numbs, atom_types, np.array(all_cells), np.array(all_coords), np.array(all_energies), np.array(all_forces), all_virials
8891

8992

90-
def analyze_block(lines, ntot, nelm) :
93+
def analyze_block(lines, ntot, nelm, ml = False):
9194
coord = []
9295
cell = []
9396
energy = None
9497
force = []
9598
virial = None
9699
is_converge = True
97100
sc_index = 0
98-
for idx,ii in enumerate(lines) :
99-
if 'Iteration' in ii:
101+
# select the search tokens and line/column offsets according to the ml flag (index 0: first-principles labels, index 1: ML labels)
102+
energy_token = ['free energy TOTEN', 'free energy ML TOTEN']
103+
energy_index = [4, 5]
104+
viral_token = ['FORCE on cell =-STRESS in cart. coord. units', 'ML FORCE']
105+
viral_index = [14, 4]
106+
cell_token = ['VOLUME and BASIS', 'ML FORCE']
107+
cell_index = [5, 12]
108+
ml_index = int(ml)
109+
for idx,ii in enumerate(lines):
110+
# when ml is True the 'Iteration' check is skipped, so is_converge always stays True
111+
if ('Iteration' in ii) and (not ml):
100112
sc_index = int(ii.split()[3][:-1])
101113
if sc_index >= nelm:
102114
is_converge = False
103-
elif 'free energy TOTEN' in ii:
104-
energy = float(ii.split()[4])
115+
elif energy_token[ml_index] in ii:
116+
energy = float(ii.split()[energy_index[ml_index]])
105117
assert((force is not None) and len(coord) > 0 and len(cell) > 0)
106-
# all_coords.append(coord)
107-
# all_cells.append(cell)
108-
# all_energies.append(energy)
109-
# all_forces.append(force)
110-
# if virial is not None :
111-
# all_virials.append(virial)
112118
return coord, cell, energy, force, virial, is_converge
113-
elif 'VOLUME and BASIS' in ii:
119+
elif cell_token[ml_index] in ii:
114120
for dd in range(3) :
115-
tmp_l = lines[idx+5+dd]
121+
tmp_l = lines[idx+cell_index[ml_index]+dd]
116122
cell.append([float(ss)
117123
for ss in tmp_l.replace('-',' -').split()[0:3]])
118-
elif 'in kB' in ii:
119-
tmp_v = [float(ss) for ss in ii.split()[2:8]]
124+
elif viral_token[ml_index] in ii:
125+
tmp_v = [float(ss) for ss in lines[idx+viral_index[ml_index]].split()[2:8]]
120126
virial = np.zeros([3,3])
121127
virial[0][0] = tmp_v[0]
122128
virial[1][1] = tmp_v[1]
@@ -127,9 +133,7 @@ def analyze_block(lines, ntot, nelm) :
127133
virial[2][1] = tmp_v[4]
128134
virial[0][2] = tmp_v[5]
129135
virial[2][0] = tmp_v[5]
130-
elif 'TOTAL-FORCE' in ii and ("ML" not in ii):
131-
# use the lines with " POSITION TOTAL-FORCE (eV/Angst)"
132-
# exclude the lines with " POSITION TOTAL-FORCE (eV/Angst) (ML)"
136+
elif 'TOTAL-FORCE' in ii and (("ML" in ii) == ml):
133137
for jj in range(idx+2, idx+2+ntot) :
134138
tmp_l = lines[jj]
135139
info = [float(ss) for ss in tmp_l.split()]

0 commit comments

Comments
 (0)