Skip to content

Commit b1ba78b

Browse files
wanghan-iapcm, Han Wang, and pre-commit-ci[bot]
authored
feat: scf convergence check in vasp .xml format. (#862)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added an option to filter VASP XML calculation steps based on SCF convergence status.
  * Introduced a new parameter to enable or disable convergence checking when loading VASP XML files.
* **Tests**
  * Added new tests to verify behavior with and without SCF convergence filtering using sample VASP XML files.
  * Included a new sample VASP XML output file for testing purposes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Han Wang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b424090 commit b1ba78b

File tree

4 files changed

+1843
-9
lines changed

4 files changed

+1843
-9
lines changed

dpdata/plugins/vasp.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def from_labeled_system(
124124
@Format.register("vasp/xml")
125125
class VASPXMLFormat(Format):
126126
@Format.post("rot_lower_triangular")
127-
def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
127+
def from_labeled_system(
128+
self, file_name, begin=0, step=1, convergence_check=True, **kwargs
129+
):
128130
data = {}
129131
(
130132
data["atom_names"],
@@ -135,7 +137,11 @@ def from_labeled_system(self, file_name, begin=0, step=1, **kwargs):
135137
data["forces"],
136138
tmp_virial,
137139
) = dpdata.vasp.xml.analyze(
138-
file_name, type_idx_zero=True, begin=begin, step=step
140+
file_name,
141+
type_idx_zero=True,
142+
begin=begin,
143+
step=step,
144+
convergence_check=convergence_check,
139145
)
140146
data["atom_numbs"] = []
141147
for ii in range(len(data["atom_names"])):

dpdata/vasp/xml.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import xml.etree.ElementTree as ET
5+
from typing import Any
56

67
import numpy as np
78

@@ -33,13 +34,46 @@ def analyze_atominfo(atominfo_xml):
3334
return eles, types
3435

3536

36-
def analyze_calculation(cc):
37+
def analyze_calculation(
38+
cc: Any,
39+
nelm: int | None,
40+
) -> tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray | None, bool | None]:
41+
"""Analyze a calculation block.
42+
43+
Parameters
44+
----------
45+
cc : xml.etree.ElementTree.Element
46+
The xml element for an ion step calculation
47+
nelm : Optional[int]
48+
The value of NELM. If it is not None, a convergence check is performed.
49+
50+
Returns
51+
-------
52+
posi : np.ndarray
53+
The positions
54+
cell : np.ndarray
55+
The cell
56+
ener : float
57+
The energy
58+
force : np.ndarray
59+
The forces
60+
strs : Optional[np.ndarray]
61+
The stress
62+
is_converged: Optional[bool]
63+
Whether the SCF calculation converged. A boolean is returned only when
64+
nelm is not None. Otherwise return None.
65+
66+
"""
3767
structure_xml = cc.find("structure")
3868
check_name(structure_xml.find("crystal").find("varray"), "basis")
3969
check_name(structure_xml.find("varray"), "positions")
4070
cell = get_varray(structure_xml.find("crystal").find("varray"))
4171
posi = get_varray(structure_xml.find("varray"))
4272
strs = None
73+
is_converged = None
74+
if nelm is not None:
75+
niter = len(cc.findall(".//scstep"))
76+
is_converged = niter < nelm
4377
for vv in cc.findall("varray"):
4478
if vv.attrib["name"] == "forces":
4579
forc = get_varray(vv)
@@ -48,9 +82,7 @@ def analyze_calculation(cc):
4882
for ii in cc.find("energy").findall("i"):
4983
if ii.attrib["name"] == "e_fr_energy":
5084
ener = float(ii.text)
51-
# print(ener)
52-
# return 'a'
53-
return posi, cell, ener, forc, strs
85+
return posi, cell, ener, forc, strs, is_converged
5486

5587

5688
def formulate_config(eles, types, posi, cell, ener, forc, strs_):
@@ -80,14 +112,24 @@ def formulate_config(eles, types, posi, cell, ener, forc, strs_):
80112
return ret
81113

82114

83-
def analyze(fname, type_idx_zero=False, begin=0, step=1):
115+
def analyze(fname, type_idx_zero=False, begin=0, step=1, convergence_check=True):
84116
"""Deal with broken xml file."""
85117
all_posi = []
86118
all_cell = []
87119
all_ener = []
88120
all_forc = []
89121
all_strs = []
90122
cc = 0
123+
if convergence_check:
124+
tree = ET.parse(fname)
125+
root = tree.getroot()
126+
parameters = root.find(".//parameters")
127+
nelm = parameters.find(".//i[@name='NELM']")
128+
# will check convergence
129+
nelm = int(nelm.text)
130+
else:
131+
# not checking convergence
132+
nelm = None
91133
try:
92134
for event, elem in ET.iterparse(fname):
93135
if elem.tag == "atominfo":
@@ -96,8 +138,16 @@ def analyze(fname, type_idx_zero=False, begin=0, step=1):
96138
if type_idx_zero:
97139
types = types - 1
98140
if elem.tag == "calculation":
99-
posi, cell, ener, forc, strs = analyze_calculation(elem)
100-
if cc >= begin and (cc - begin) % step == 0:
141+
posi, cell, ener, forc, strs, is_converged = analyze_calculation(
142+
elem, nelm
143+
)
144+
# record when not checking convergence or is_converged
145+
# and the step criteria is satisfied
146+
if (
147+
(nelm is None or is_converged)
148+
and cc >= begin
149+
and (cc - begin) % step == 0
150+
):
101151
all_posi.append(posi)
102152
all_cell.append(cell)
103153
all_ener.append(ener)

0 commit comments

Comments
 (0)