Skip to content

Commit f57657a

Browse files
authored
Merge pull request #110 from benrich37/faster-parsing
Misc edits
2 parents af19747 + 0bc390e commit f57657a

File tree

11 files changed

+289
-75
lines changed

11 files changed

+289
-75
lines changed

src/pymatgen/io/jdftx/inputs.py

Lines changed: 153 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@ class is written.
2121
from pymatgen.core import Lattice, Structure
2222
from pymatgen.core.periodic_table import Element
2323
from pymatgen.core.units import bohr_to_ang
24-
from pymatgen.io.jdftx.generic_tags import AbstractTag, BoolTagContainer, DumpTagContainer, MultiformatTag, TagContainer
24+
from pymatgen.io.jdftx.generic_tags import (
25+
AbstractTag,
26+
BoolTagContainer,
27+
DumpTagContainer,
28+
FloatTag,
29+
MultiformatTag,
30+
TagContainer,
31+
)
2532
from pymatgen.io.jdftx.jdftxinfile_default_inputs import default_inputs
2633
from pymatgen.io.jdftx.jdftxinfile_master_format import (
2734
__PHONON_TAGS__,
@@ -45,14 +52,15 @@ class is written.
4552

4653
__author__ = "Jacob Clary, Ben Rich"
4754

55+
4856
# TODO: Add check for whether all ions have or lack velocities.
4957
# TODO: Add default value filling like JDFTx does.
5058
# TODO: Add more robust checking for if two repeatable tag values represent the
5159
# same information. This is likely fixed by implementing filling of default values.
5260
# TODO: Incorporate something to collapse repeated dump tags of the same frequency
5361
# into a single value.
54-
55-
62+
# TODO: Add a method to strip all tags that contain their default values for simpler
63+
# files written (especially when a `JDFTXInfile` is created from `JDFTXOutfileSlice`)
5664
class JDFTXInfile(dict, MSONable):
5765
"""Class for reading/writing JDFtx input files.
5866
@@ -270,6 +278,32 @@ def from_jdftxstructure(
270278
jstr = jdftxstructure.get_str()
271279
return cls.from_str(jstr)
272280

281+
def read_line(
282+
self,
283+
line: str,
284+
validate_value_boundaries: bool = True,
285+
autofix: bool = True,
286+
overwrite_nonrepeatable: bool = True,
287+
) -> None:
288+
"""Read a single line and update the JDFTXInfile object.
289+
290+
Convenience method for reading a single line and updating the JDFTXInfile object.
291+
292+
Args:
293+
line (str): Line to read.
294+
"""
295+
line = line.strip()
296+
tag_object, tag, value = self._preprocess_line(line)
297+
if not tag_object.can_repeat and overwrite_nonrepeatable and tag in self:
298+
del self[tag]
299+
processed_value = tag_object.read(tag, value)
300+
_params = self.as_dict(skip_module_keys=True)
301+
_params = self._store_value(_params, tag_object, tag, processed_value)
302+
self.update(_params)
303+
self.validate_tags(try_auto_type_fix=autofix, error_on_failed_fix=True)
304+
if validate_value_boundaries:
305+
self.validate_boundaries()
306+
273307
@classmethod
274308
def from_str(
275309
cls,
@@ -397,7 +431,9 @@ def copy(self) -> JDFTXInfile:
397431
Returns:
398432
JDFTXInfile: Copy of the JDFTXInfile object.
399433
"""
400-
return type(self)(self)
434+
# Wasn't working before
435+
# return type(self)(self)
436+
return self.from_dict(self.as_dict(skip_module_keys=True), validate_value_boundaries=False)
401437

402438
def get_text_list(self) -> list[str]:
403439
"""Get a list of strings representation of the JDFTXInfile.
@@ -426,14 +462,24 @@ def get_text_list(self) -> list[str]:
426462
text.append("")
427463
return text
428464

429-
def write_file(self, filename: PathLike) -> None:
465+
# TODO: JDFTXInfile can accept nan for values, as this is occasionally what is stored
466+
# for unused variables, but JDFTx has no way read nan for an input value. All subtags
467+
# with nan values should be removed before writing to file.
468+
# TODO: Detect for and warn for tags that can be used together but likely shouldn't be,
469+
# ie (ion-width being 0 while fluid is not None)
470+
def write_file(self, filename: PathLike, strip_nan: bool = False) -> None:
430471
"""Write JDFTXInfile to an in file.
431472
432473
Args:
433474
filename (PathLike): Filename to write to.
475+
strip_nan (bool, optional): Whether to strip all subtags with nan values before writing.
476+
Defaults to False. WARNING - VERY JANKY RIGHT NOW
434477
"""
478+
write_infile = self
479+
if strip_nan:
480+
write_infile = clean_infile_of_nans(self)
435481
with open(filename, mode="w") as file:
436-
file.write(str(self))
482+
file.write(str(write_infile))
437483

438484
@classmethod
439485
def to_jdftxstructure(
@@ -912,6 +958,107 @@ def movescale_array_to_selective_dynamics_site_prop(movescale: ArrayLike[int | f
912958
return selective_dynamics
913959

914960

961+
# def _strip_nans(infile: JDFTXInfile) -> JDFTXInfile:
962+
# for k, v in infile.items():
963+
# tag_object = get_tag_object_on_val(k, v)
964+
# if isinstance(v, dict):
965+
# infile[k] = _strip_nans(v)
966+
# if isinstance(v, float) and np.isnan(v):
967+
# infile[k] = None
968+
# elif isinstance(v, list):
969+
# infile[k] = [x for x in v if not (isinstance(x, float) and np.isnan(x))]
970+
# elif isinstance(v, dict):
971+
# infile[k] = _strip_nans(v)
972+
973+
# def _has_nans(value: float | list) -> bool:
974+
# if isinstance(value, float) and np.isnan(value):
975+
# return True
976+
# elif isinstance(value, list):
977+
# return any(_has_nans(x) for x in value)
978+
# return False
979+
980+
981+
def _isnan(x):
982+
try:
983+
return np.isnan(x)
984+
except TypeError:
985+
return False
986+
987+
988+
def _check_tagcontainer_for_nan(tag_container: TagContainer, val_dict: dict):
989+
hasnans = []
990+
for kk, vv in val_dict.items():
991+
tag = tag_container.subtags[kk]
992+
if not isinstance(tag, TagContainer):
993+
if _isnan(vv):
994+
print(f"Tag {kk} has nan value")
995+
hasnans.append(kk)
996+
elif not tag.can_repeat:
997+
_hasnans = _check_tagcontainer_for_nan(tag, vv)
998+
if len(_hasnans) > 0:
999+
hasnans.append({kk: _hasnans})
1000+
return hasnans
1001+
1002+
1003+
def has_nan_in_required_subtag(tag_container: TagContainer, val_dict: dict):
1004+
for kk, vv in val_dict.items():
1005+
tag = tag_container.subtags[kk]
1006+
if not isinstance(tag, TagContainer):
1007+
if _isnan(vv) and not tag.optional:
1008+
return True
1009+
elif not tag.can_repeat and has_nan_in_required_subtag(tag, vv):
1010+
return True
1011+
return False
1012+
1013+
1014+
def clean_tagcontainer_of_nans(tag_container: TagContainer, val_dict: dict):
1015+
subtags_to_delete = []
1016+
for kk, vv in val_dict.items():
1017+
tag = tag_container.subtags[kk]
1018+
if not isinstance(tag, TagContainer):
1019+
if _isnan(vv):
1020+
print(f"Removing tag {kk} with nan value")
1021+
subtags_to_delete.append(kk)
1022+
elif not tag.can_repeat:
1023+
if has_nan_in_required_subtag(tag, vv) and tag_container.optional:
1024+
subtags_to_delete.append(kk)
1025+
else:
1026+
clean_tagcontainer_of_nans(tag, vv)
1027+
if len(tag.subtags) == 0:
1028+
print(f"Removing empty tag container {kk}")
1029+
subtags_to_delete.append(kk)
1030+
clean_tagcontainer_of_nans(tag, vv)
1031+
for kk in subtags_to_delete:
1032+
val_dict.pop(kk)
1033+
1034+
1035+
def clean_infile_of_nans(infile: JDFTXInfile) -> JDFTXInfile:
1036+
hasnans = []
1037+
for k, v in infile.items():
1038+
tag = get_tag_object_on_val(k, v)
1039+
if isinstance(tag, FloatTag) and _isnan(v):
1040+
print(f"Tag {k} has nan value")
1041+
elif isinstance(tag, TagContainer) and not tag.can_repeat:
1042+
_hasnans = _check_tagcontainer_for_nan(tag, v)
1043+
if len(_hasnans) > 0:
1044+
hasnans.append({k: _hasnans})
1045+
infile_cleaned = JDFTXInfile.from_dict(infile.as_dict())
1046+
1047+
# infile_cleaned = JDFTXInfile(infile)
1048+
1049+
for h in hasnans:
1050+
if isinstance(h, str):
1051+
infile_cleaned.pop(h)
1052+
elif isinstance(h, dict):
1053+
for _h in h:
1054+
tag = get_tag_object_on_val(_h, infile_cleaned[_h])
1055+
if _isnan(infile_cleaned[_h]):
1056+
infile_cleaned.pop(_h)
1057+
else:
1058+
clean_tagcontainer_of_nans(tag, infile_cleaned[_h])
1059+
return infile_cleaned
1060+
1061+
9151062
@dataclass
9161063
class JDFTXStructure(MSONable):
9171064
"""Object for representing the data in JDFTXStructure tags.

src/pymatgen/io/jdftx/jdftxinfile_default_inputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# One example is fluid-minimize, which has different convergence thresholds and max iterations depending
88
# on the algorithm specified. For these tags, a second set of default values which can map partially
99
# filled tagcontainers to the set as filled by JDFTx is needed.
10+
# TODO: Make sure ion-width, which changes based on 'fluid' tag, is handled correctly.
1011
default_inputs = {
1112
"basis": "kpoint-dependent",
1213
"coords-type": "Lattice",

src/pymatgen/io/jdftx/jdftxinfile_master_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@
574574
write_tagname=False,
575575
optional=False,
576576
),
577-
"origin": TagContainer(
577+
"center": TagContainer(
578578
allow_list_representation=True,
579579
optional=True,
580580
subtags={

src/pymatgen/io/jdftx/jdftxoutfileslice.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ class JDFTXOutfileSlice:
291291
converged: bool | None = None
292292
structure: Structure | None = None
293293
initial_structure: Structure | None = None
294-
trajectory: Trajectory | None = None
294+
_trajectory: Trajectory | None = None
295295
electronic_output: dict | None = None
296296
eopt_type: str | None = None
297297
elecmindata: JElSteps | None = None
@@ -336,10 +336,10 @@ def _get_mu(self) -> None | float:
336336
def _from_out_slice(
337337
cls,
338338
text: list[str],
339+
skim_levels: list[str],
340+
skip_props: list[str],
339341
is_bgw: bool = False,
340342
none_on_error: bool = True,
341-
skim_levels: list[str] | None = None,
342-
skip_props: list[str] | None = None,
343343
) -> JDFTXOutfileSlice | None:
344344
"""
345345
Read slice of out file into a JDFTXOutfileSlice instance.
@@ -356,16 +356,14 @@ def _from_out_slice(
356356
instance = cls()
357357
instance.is_bgw = is_bgw
358358
try:
359-
instance._from_out_slice_init_all(text, skim_levels=skim_levels, skip_props=skip_props)
359+
instance._from_out_slice_init_all(text, skim_levels, skip_props)
360360
except (ValueError, IndexError, TypeError, KeyError, AttributeError):
361361
if none_on_error:
362362
return None
363363
raise
364364
return instance
365365

366-
def _from_out_slice_init_all(
367-
self, text: list[str], skim_levels: list[str] | None = None, skip_props: list[str] | None = None
368-
) -> None:
366+
def _from_out_slice_init_all(self, text: list[str], skim_levels: list[str], skip_props: list[str]) -> None:
369367
self._set_internal_infile(text)
370368
# self._set_min_settings(text)
371369
self._set_geomopt_vars(text)
@@ -470,12 +468,23 @@ def _set_trajectory(self) -> None:
470468
structures = [slc.structure for slc in self.jstrucs]
471469
constant_lattice = self.constant_lattice if self.constant_lattice is not None else False
472470
frame_properties = [slc.properties for slc in self.jstrucs]
473-
self.trajectory = Trajectory.from_structures(
471+
self._trajectory = Trajectory.from_structures(
474472
structures=structures,
475473
constant_lattice=constant_lattice,
476474
frame_properties=frame_properties,
477475
)
478476

477+
@property
478+
def trajectory(self) -> Trajectory | None:
479+
"""Return pymatgen trajectory object.
480+
481+
Returns:
482+
Trajectory: pymatgen Trajectory object containing intermediate Structure's of outfile slice calculation.
483+
"""
484+
if self._trajectory is None:
485+
self._set_trajectory()
486+
return self._trajectory
487+
479488
def _set_electronic_output(self) -> None:
480489
"""Return a dictionary with all relevant electronic information.
481490
@@ -884,9 +893,7 @@ def _get_initial_structure(self, text: list[str]) -> Structure | None:
884893
raise ValueError("Provided out file slice's inputs preamble does not contain input structure data.")
885894
return init_struc
886895

887-
def _set_jstrucs(
888-
self, text: list[str], skim_levels: list[str] | None = None, skip_props: list[str] | None = None
889-
) -> None:
896+
def _set_jstrucs(self, text: list[str], skim_levels: list[str], skip_props: list[str]) -> None:
890897
"""Set the jstrucs class variable.
891898
892899
Set the JStructures object to jstrucs from the out file text and all class attributes initialized from jstrucs.
@@ -903,17 +910,19 @@ def _set_jstrucs(
903910
expected_etype = "G"
904911
self.jstrucs = JOutStructures._from_out_slice(
905912
text,
913+
# skim_levels,
914+
# skip_props,
906915
opt_type=self.geom_opt_label,
907916
init_struc=self.initial_structure,
908917
is_md=self.is_md,
909918
expected_etype=expected_etype,
910919
skim_levels=skim_levels,
911920
skip_props=skip_props,
912921
)
913-
if self.etype is None:
914-
self.etype = self.jstrucs[-1].etype
915922
if self.jstrucs is not None:
916-
self._set_trajectory()
923+
if self.etype is None:
924+
self.etype = self.jstrucs[-1].etype
925+
# self._set_trajectory()
917926
self.mu = self._get_mu()
918927
for var in _jofs_atr_from_jstrucs:
919928
setattr(self, var, getattr(self.jstrucs, var))

src/pymatgen/io/jdftx/joutstructure.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class JOutStructure(Structure):
108108
elec_grad_k: float | np.float64 | None = None
109109
elec_alpha: float | np.float64 | None = None
110110
elec_linmin: float | np.float64 | None = None
111-
structure: Structure | None = None
111+
_structure: Structure | None = None
112112
is_md: bool = False
113113
# thermostat_velocity: np.ndarray | None = None
114114
_velocities: list[NDArray[np.float64] | None] | None = None
@@ -386,7 +386,6 @@ def _from_text_slice(
386386
instance._parse_lattice_lines(line_collections["lattice"]["lines"])
387387
# Posns must be parsed before forces and lowdin analysis so that they can be stored in site_properties
388388
cur_species = instance._parse_posns_lines(line_collections["posns"]["lines"], cur_species)
389-
390389
if "forces" not in skip_props:
391390
instance._parse_forces_lines(line_collections["forces"]["lines"])
392391
if "lowdin" not in skip_props:
@@ -406,9 +405,16 @@ def _from_text_slice(
406405
# Set relevant properties in self.properties
407406
instance._fill_properties()
408407
# Done last in case of any changes to site-properties
409-
instance._init_structure(cur_species)
408+
# instance._init_structure(cur_species)
410409
return instance
411410

411+
@property
412+
def structure(self) -> Structure:
413+
"""Return structure attribute."""
414+
if self._structure is None:
415+
self._init_structure()
416+
return self._structure
417+
412418
def _init_e_sp_backup(self) -> None:
413419
"""Initialize self.e with coverage for single-point calculations."""
414420
err_str = None
@@ -967,11 +973,11 @@ def _fill_properties(self) -> None:
967973
"strain": self.strain,
968974
}
969975

970-
def _init_structure(self, cur_species: Sequence[Element | Species]) -> None:
976+
def _init_structure(self) -> None:
971977
"""Initialize structure attribute."""
972-
self.structure = Structure(
978+
self._structure = Structure(
973979
lattice=self.lattice,
974-
species=cur_species,
980+
species=self.species,
975981
coords=self.cart_coords,
976982
site_properties=self.site_properties,
977983
coords_are_cartesian=True,

0 commit comments

Comments
 (0)