Skip to content

Commit 9fea170

Browse files
authored
JDFTXOutfileSlice Durability Improvement (#4418)
* Protecting the final JOutStructure with try/except so in-progress out files can be parsed * Enabling partial eigstats filling * Variety of small changes to improve durability of parsing partially finished out files - can return a usable `JDFTXOutfileSlice` as long as electronic minimize has started * Removing commented out code * Removing commented out code * Removing walk-through-helping lines from test
1 parent c6031ff commit 9fea170

File tree

5 files changed

+81
-35
lines changed

5 files changed

+81
-35
lines changed

src/pymatgen/io/jdftx/_output_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,17 @@ def find_all_key(key_input: str, tempfile: list[str], startline: int = 0) -> lis
336336
return [i for i in range(startline, len(tempfile)) if key_input in tempfile[i]]
337337

338338

339+
def _init_dict_from_colon_dump_lines(lines: list[str]):
340+
varsdict = {}
341+
for line in lines:
342+
if ":" in line:
343+
lsplit = line.split(":")
344+
key = lsplit[0].strip()
345+
val = lsplit[1].split()[0].strip()
346+
varsdict[key] = val
347+
return varsdict
348+
349+
339350
def _parse_bandfile_complex(bandfile_filepath: str | Path) -> NDArray[np.complex64]:
340351
Dtype: TypeAlias = np.complex64
341352
token_parser = _complex_token_parser

src/pymatgen/io/jdftx/jdftxoutfileslice.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pymatgen.core.trajectory import Trajectory
2626
from pymatgen.core.units import Ha_to_eV, ang_to_bohr, bohr_to_ang
2727
from pymatgen.io.jdftx._output_utils import (
28+
_init_dict_from_colon_dump_lines,
2829
find_all_key,
2930
find_first_range_key,
3031
find_key,
@@ -220,6 +221,7 @@ class JDFTXOutfileSlice:
220221
_electronic_output: ClassVar[list[str]] = [
221222
"efermi",
222223
"egap",
224+
"optical_egap",
223225
"emin",
224226
"emax",
225227
"homo",
@@ -230,6 +232,7 @@ class JDFTXOutfileSlice:
230232
]
231233
efermi: float | None = None
232234
egap: float | None = None
235+
optical_egap: float | None = None
233236
emin: float | None = None
234237
emax: float | None = None
235238
homo: float | None = None
@@ -666,22 +669,19 @@ def _get_eigstats_varsdict(self, text: list[str], prefix: str | None) -> dict[st
666669
lines2 = find_all_key("eigStats' ...", text)
667670
lines3 = [lines1[i] for i in range(len(lines1)) if lines1[i] in lines2]
668671
if not lines3:
669-
varsdict["emin"] = None
670-
varsdict["homo"] = None
671-
varsdict["efermi"] = None
672-
varsdict["lumo"] = None
673-
varsdict["emax"] = None
674-
varsdict["egap"] = None
672+
for key in list(eigstats_keymap.keys()):
673+
varsdict[eigstats_keymap[key]] = None
675674
self.has_eigstats = False
676675
else:
677-
line = lines3[-1]
678-
varsdict["emin"] = float(text[line + 1].split()[1]) * Ha_to_eV
679-
varsdict["homo"] = float(text[line + 2].split()[1]) * Ha_to_eV
680-
varsdict["efermi"] = float(text[line + 3].split()[2]) * Ha_to_eV
681-
varsdict["lumo"] = float(text[line + 4].split()[1]) * Ha_to_eV
682-
varsdict["emax"] = float(text[line + 5].split()[1]) * Ha_to_eV
683-
varsdict["egap"] = float(text[line + 6].split()[2]) * Ha_to_eV
684-
self.has_eigstats = True
676+
line_start = lines3[-1]
677+
line_start_rel_idx = lines1.index(line_start)
678+
line_end = lines1[line_start_rel_idx + 1] if len(lines1) >= line_start_rel_idx + 2 else len(lines1) - 1
679+
_varsdict = _init_dict_from_colon_dump_lines([text[idx] for idx in range(line_start, line_end)])
680+
for key in _varsdict:
681+
varsdict[eigstats_keymap[key]] = float(_varsdict[key]) * Ha_to_eV
682+
self.has_eigstats = all(eigstats_keymap[key] in varsdict for key in eigstats_keymap) and all(
683+
eigstats_keymap[key] is not None for key in eigstats_keymap
684+
)
685685
return varsdict
686686

687687
def _set_eigvars(self, text: list[str]) -> None:
@@ -691,12 +691,8 @@ def _set_eigvars(self, text: list[str]) -> None:
691691
text (list[str]): Output of read_file for out file.
692692
"""
693693
eigstats = self._get_eigstats_varsdict(text, self.prefix)
694-
self.emin = eigstats["emin"]
695-
self.homo = eigstats["homo"]
696-
self.efermi = eigstats["efermi"]
697-
self.lumo = eigstats["lumo"]
698-
self.emax = eigstats["emax"]
699-
self.egap = eigstats["egap"]
694+
for key, val in eigstats.items():
695+
setattr(self, key, val)
700696
if self.efermi is None:
701697
if self.mu is None:
702698
self.mu = self._get_mu()
@@ -1063,12 +1059,9 @@ def _set_atom_vars(self, text: list[str]) -> None:
10631059
self.atom_elements = atom_elements
10641060
self.atom_elements_int = [Element(x).Z for x in self.atom_elements]
10651061
self.atom_types = atom_types
1066-
line = find_key("# Ionic positions in", text)
1067-
if line is not None:
1068-
line += 1
1069-
coords = np.array([text[i].split()[2:5] for i in range(line, line + self.nat)], dtype=float)
1070-
self.atom_coords_final = coords
1071-
self.atom_coords = coords.copy()
1062+
if isinstance(self.structure, Structure):
1063+
self.atom_coords = self.structure.cart_coords
1064+
self.atom_coords_final = self.structure.cart_coords
10721065

10731066
def _set_lattice_vars(self, text: list[str]) -> None:
10741067
"""Set the lattice variables.
@@ -1246,6 +1239,17 @@ def __str__(self) -> str:
12461239
return pprint.pformat(self)
12471240

12481241

1242+
# Maps the labels found in the eigStats colon-dump lines of an out file to the
# attribute names under which the corresponding values are stored.
eigstats_keymap: dict[str, str] = {
    "eMin": "emin",
    "HOMO": "homo",
    "mu": "efermi",
    "LUMO": "lumo",
    "eMax": "emax",
    "HOMO-LUMO gap": "egap",
    "Optical gap": "optical_egap",
}
1251+
1252+
12491253
def get_pseudo_read_section_bounds(text: list[str]) -> list[list[int]]:
12501254
"""Get the boundary line numbers for the pseudopotential read section.
12511255

src/pymatgen/io/jdftx/joutstructure.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def _parse_lattice_lines(self, lattice_lines: list[str]) -> None:
584584
order (vec i = r[i,:]) and converts from Bohr to Angstroms.
585585
"""
586586
r = None
587-
if len(lattice_lines):
587+
if len(lattice_lines) >= 5:
588588
r = _brkt_list_of_3x3_to_nparray(lattice_lines, i_start=2)
589589
r = r.T * bohr_to_ang
590590
self.lattice = Lattice(np.array(r))
@@ -597,7 +597,7 @@ def _parse_strain_lines(self, strain_lines: list[str]) -> None:
597597
strain tensor. Converts from column-major to row-major order.
598598
"""
599599
st = None
600-
if len(strain_lines):
600+
if len(strain_lines) == 4:
601601
st = _brkt_list_of_3x3_to_nparray(strain_lines, i_start=1)
602602
st = st.T
603603
self.strain = st
@@ -616,7 +616,7 @@ def _parse_stress_lines(self, stress_lines: list[str]) -> None:
616616
# "[Eh/a0^3]" (Hartree per bohr cubed). Check if this changes for direct
617617
# coordinates.
618618
st = None
619-
if len(stress_lines):
619+
if len(stress_lines) == 4:
620620
st = _brkt_list_of_3x3_to_nparray(stress_lines, i_start=1)
621621
st = st.T
622622
st *= Ha_to_eV / (bohr_to_ang**3)
@@ -636,7 +636,7 @@ def _parse_kinetic_stress_lines(self, stress_lines: list[str]) -> None:
636636
# "[Eh/a0^3]" (Hartree per bohr cubed). Check if this changes for direct
637637
# coordinates.
638638
st = None
639-
if len(stress_lines):
639+
if len(stress_lines) == 4:
640640
st = _brkt_list_of_3x3_to_nparray(stress_lines, i_start=1)
641641
st = st.T
642642
st *= Ha_to_eV / (bohr_to_ang**3)
@@ -659,6 +659,18 @@ def _parse_thermostat_line(self, posns_lines: list[str]) -> None:
659659
else:
660660
self.thermostat_velocity = None
661661

662+
def _check_for_structure_consistency(self, names: list[str]) -> bool:
663+
# If JOutStructure was constructed with a reference init_structure
664+
if len(self.species):
665+
if len(names) != len(self.species):
666+
return False
667+
_names = list(set(names))
668+
_self_names = [s.symbol for s in self.species]
669+
for _name in _names:
670+
if names.count(_name) != _self_names.count(_name):
671+
return False
672+
return True
673+
662674
def _parse_posns_lines(self, posns_lines: list[str]) -> None:
663675
"""Parse positions lines.
664676
@@ -673,8 +685,8 @@ def _parse_posns_lines(self, posns_lines: list[str]) -> None:
673685
the name of the element, and sd is a flag indicating whether the ion is
674686
excluded from optimization (1) or not (0).
675687
"""
688+
self.copy()
676689
if len(posns_lines):
677-
self.remove_sites(list(range(len(self.species))))
678690
coords_type = posns_lines[0].split("positions in")[1]
679691
coords_type = coords_type.strip().split()[0].strip()
680692
_posns: list[NDArray[np.float64]] = []
@@ -697,6 +709,11 @@ def _parse_posns_lines(self, posns_lines: list[str]) -> None:
697709
constraint_types.append(constraint_type)
698710
constraint_vectors.append(constraint_vector)
699711
group_names_list.append(group_names)
712+
is_good = self._check_for_structure_consistency(names)
713+
if not is_good and len(self.species):
714+
# Abort structure updating if we have a pre-existing structure
715+
return
716+
self.remove_sites(list(range(len(self.species))))
700717
posns = np.array(_posns)
701718
if coords_type.lower() != "cartesian":
702719
posns = np.dot(posns, self.lattice.matrix)

src/pymatgen/io/jdftx/joutstructures.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,18 @@ def _get_joutstructure_list(
384384
for i, bounds in enumerate(out_bounds):
385385
if i > 0:
386386
init_structure = joutstructure_list[-1]
387-
joutstructure_list.append(
388-
JOutStructure._from_text_slice(
387+
joutstructure = None
388+
# The final out_slice slice is protected by the try/except block, as this slice has a high
389+
# chance of being empty or malformed.
390+
try:
391+
joutstructure = JOutStructure._from_text_slice(
389392
out_slice[bounds[0] : bounds[1]],
390393
init_structure=init_structure,
391394
opt_type=opt_type,
392395
)
393-
)
396+
except (ValueError, IndexError, TypeError, KeyError, AttributeError):
397+
if not i == len(out_bounds) - 1:
398+
raise
399+
if joutstructure is not None:
400+
joutstructure_list.append(joutstructure)
394401
return joutstructure_list

tests/io/jdftx/test_jdftxoutfileslice.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def test_as_dict():
208208
assert isinstance(out_dict, dict)
209209

210210

211+
def should_be_parsable_out_slice(out_slice: list[str]) -> bool:
    """Return True if any line of ``out_slice`` contains "ElecMinimize: Iter:".

    Lines are scanned from the end of the slice towards the start.
    """
    marker = "ElecMinimize: Iter:"
    for line in reversed(out_slice):
        if marker in line:
            return True
    return False
213+
214+
211215
# Make sure all possible exceptions are caught when none_on_error is True
212216
@pytest.mark.parametrize(("ex_slice"), [(ex_slice1)])
213217
def test_none_on_partial(ex_slice: list[str]):
@@ -216,4 +220,7 @@ def test_none_on_partial(ex_slice: list[str]):
216220
for i in range(int(len(ex_slice) / freq)):
217221
test_slice = ex_slice[: -(i * freq)]
218222
joutslice = JDFTXOutfileSlice._from_out_slice(test_slice, none_on_error=True)
219-
assert isinstance(joutslice, JDFTXOutfileSlice | None)
223+
if should_be_parsable_out_slice(test_slice):
224+
assert isinstance(joutslice, JDFTXOutfileSlice | None)
225+
else:
226+
assert isinstance(joutslice, JDFTXOutfileSlice | None)

0 commit comments

Comments
 (0)