Skip to content

Commit 7d797d4

Browse files
authored
Merge pull request #112 from benrich37/missing-unit-conversion
Slightly cleaner nan stripping
2 parents bb39b73 + 598bea5 commit 7d797d4

File tree

2 files changed

+99
-104
lines changed

2 files changed

+99
-104
lines changed

src/pymatgen/io/jdftx/inputs.py

Lines changed: 75 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,7 @@ class is written.
2121
from pymatgen.core import Lattice, Structure
2222
from pymatgen.core.periodic_table import Element
2323
from pymatgen.core.units import bohr_to_ang
24-
from pymatgen.io.jdftx.generic_tags import (
25-
AbstractTag,
26-
BoolTagContainer,
27-
DumpTagContainer,
28-
FloatTag,
29-
MultiformatTag,
30-
TagContainer,
31-
)
24+
from pymatgen.io.jdftx.generic_tags import AbstractTag, BoolTagContainer, DumpTagContainer, MultiformatTag, TagContainer
3225
from pymatgen.io.jdftx.jdftxinfile_default_inputs import default_inputs
3326
from pymatgen.io.jdftx.jdftxinfile_master_format import (
3427
__PHONON_TAGS__,
@@ -958,105 +951,84 @@ def movescale_array_to_selective_dynamics_site_prop(movescale: ArrayLike[int | f
958951
return selective_dynamics
959952

960953

961-
# def _strip_nans(infile: JDFTXInfile) -> JDFTXInfile:
962-
# for k, v in infile.items():
963-
# tag_object = get_tag_object_on_val(k, v)
964-
# if isinstance(v, dict):
965-
# infile[k] = _strip_nans(v)
966-
# if isinstance(v, float) and np.isnan(v):
967-
# infile[k] = None
968-
# elif isinstance(v, list):
969-
# infile[k] = [x for x in v if not (isinstance(x, float) and np.isnan(x))]
970-
# elif isinstance(v, dict):
971-
# infile[k] = _strip_nans(v)
972-
973-
# def _has_nans(value: float | list) -> bool:
974-
# if isinstance(value, float) and np.isnan(value):
975-
# return True
976-
# elif isinstance(value, list):
977-
# return any(_has_nans(x) for x in value)
978-
# return False
979-
980-
981-
def _isnan(x):
982-
try:
983-
return np.isnan(x)
984-
except TypeError:
985-
return False
986-
987-
988-
def _check_tagcontainer_for_nan(tag_container: TagContainer, val_dict: dict):
989-
hasnans = []
990-
for kk, vv in val_dict.items():
991-
tag = tag_container.subtags[kk]
992-
if not isinstance(tag, TagContainer):
993-
if _isnan(vv):
994-
print(f"Tag {kk} has nan value")
995-
hasnans.append(kk)
996-
elif not tag.can_repeat:
997-
_hasnans = _check_tagcontainer_for_nan(tag, vv)
998-
if len(_hasnans) > 0:
999-
hasnans.append({kk: _hasnans})
1000-
return hasnans
1001-
1002-
1003-
def has_nan_in_required_subtag(tag_container: TagContainer, val_dict: dict):
1004-
for kk, vv in val_dict.items():
1005-
tag = tag_container.subtags[kk]
1006-
if not isinstance(tag, TagContainer):
1007-
if _isnan(vv) and not tag.optional:
1008-
return True
1009-
elif not tag.can_repeat and has_nan_in_required_subtag(tag, vv):
1010-
return True
1011-
return False
954+
def clean_infile_of_nans(infile: JDFTXInfile) -> JDFTXInfile:
955+
infile_cleaned = infile.copy()
956+
popmaps = get_popmaps(infile_cleaned)
957+
apply_popmaps(infile_cleaned, popmaps)
958+
return infile_cleaned
1012959

1013960

1014-
def clean_tagcontainer_of_nans(tag_container: TagContainer, val_dict: dict):
1015-
subtags_to_delete = []
1016-
for kk, vv in val_dict.items():
1017-
tag = tag_container.subtags[kk]
1018-
if not isinstance(tag, TagContainer):
1019-
if _isnan(vv):
1020-
print(f"Removing tag {kk} with nan value")
1021-
subtags_to_delete.append(kk)
1022-
elif not tag.can_repeat:
1023-
if has_nan_in_required_subtag(tag, vv) and tag_container.optional:
1024-
subtags_to_delete.append(kk)
1025-
else:
1026-
clean_tagcontainer_of_nans(tag, vv)
1027-
if len(tag.subtags) == 0:
1028-
print(f"Removing empty tag container {kk}")
1029-
subtags_to_delete.append(kk)
1030-
clean_tagcontainer_of_nans(tag, vv)
1031-
for kk in subtags_to_delete:
1032-
val_dict.pop(kk)
961+
def hasnan(val: Any) -> bool:
962+
if type(val) in [float, int]:
963+
return np.isnan(val)
964+
if isinstance(val, list):
965+
for v in val:
966+
if hasnan(v):
967+
return True
968+
elif isinstance(val, dict):
969+
for v in val.values():
970+
if hasnan(v):
971+
return True
972+
return False
1033973

1034974

1035-
def clean_infile_of_nans(infile: JDFTXInfile) -> JDFTXInfile:
1036-
hasnans = []
1037-
for k, v in infile.items():
1038-
tag = get_tag_object_on_val(k, v)
1039-
if isinstance(tag, FloatTag) and _isnan(v):
1040-
print(f"Tag {k} has nan value")
1041-
elif isinstance(tag, TagContainer) and not tag.can_repeat:
1042-
_hasnans = _check_tagcontainer_for_nan(tag, v)
1043-
if len(_hasnans) > 0:
1044-
hasnans.append({k: _hasnans})
1045-
infile_cleaned = JDFTXInfile.from_dict(infile.as_dict())
1046-
1047-
# infile_cleaned = JDFTXInfile(infile)
1048-
1049-
for h in hasnans:
1050-
if isinstance(h, str):
1051-
infile_cleaned.pop(h)
1052-
elif isinstance(h, dict):
1053-
for _h in h:
1054-
tag = get_tag_object_on_val(_h, infile_cleaned[_h])
1055-
if _isnan(infile_cleaned[_h]):
1056-
infile_cleaned.pop(_h)
1057-
else:
1058-
clean_tagcontainer_of_nans(tag, infile_cleaned[_h])
1059-
return infile_cleaned
975+
def get_nanmaps(val: Any) -> list[list[str | int]]:
976+
nanmaps: list[list[str | int]] = []
977+
generator = None
978+
if type(val) in [float, int]:
979+
if hasnan(val):
980+
nanmaps.append([])
981+
elif isinstance(val, dict):
982+
generator = [(k, v) for k, v in val.items()]
983+
elif isinstance(val, list):
984+
generator = [(i, v) for i, v in enumerate(val)]
985+
elif type(val in [str]):
986+
pass
987+
else:
988+
raise ValueError(f"Unexpected type {type(val)} of value {val}")
989+
if generator is not None:
990+
for k, v in generator:
991+
submaps = get_nanmaps(v)
992+
for submap in submaps:
993+
add_map = [k]
994+
add_map.extend(submap)
995+
nanmaps.append(add_map)
996+
return nanmaps
997+
998+
999+
def get_popmaps(infile: JDFTXInfile) -> list[list[str | int]]:
1000+
popmaps = []
1001+
for tag in infile:
1002+
nanmaps = get_nanmaps(infile[tag])
1003+
for nanmap in nanmaps:
1004+
add_map = [tag]
1005+
add_map.extend(nanmap)
1006+
popmaps.append(add_map)
1007+
return popmaps
1008+
1009+
1010+
def apply_popmaps(edit_infile: JDFTXInfile, popmaps: list[list[str | int]]) -> None:
1011+
for popmap in popmaps:
1012+
popmap_use = popmap
1013+
popmap_destination = edit_infile
1014+
to = get_tag_object_on_val(popmap[0], edit_infile[popmap[0]])
1015+
if isinstance(to, TagContainer):
1016+
subtag = to
1017+
popmap_destination = popmap_destination[popmap[0]]
1018+
popmap_route = popmap[1:]
1019+
popmap_use = [popmap[0]]
1020+
step = None
1021+
for i, step in enumerate(popmap_route):
1022+
if isinstance(step, str):
1023+
subtag = subtag.subtags[step]
1024+
if not subtag.optional:
1025+
step = popmap_route[i - 1]
1026+
break
1027+
popmap_use.append(step)
1028+
popmap_destination = edit_infile
1029+
for step in popmap_use[:-1]:
1030+
popmap_destination = popmap_destination[step]
1031+
popmap_destination.pop(popmap_use[-1], None)
10601032

10611033

10621034
@dataclass

tests/io/jdftx/test_jdftxinfile.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
from pymatgen.core.structure import Site, Structure
1111
from pymatgen.core.units import bohr_to_ang
12-
from pymatgen.io.jdftx.inputs import JDFTXInfile, JDFTXStructure, selective_dynamics_site_prop_to_jdftx_interpretable
12+
from pymatgen.io.jdftx.inputs import (
13+
JDFTXInfile,
14+
JDFTXStructure,
15+
clean_infile_of_nans,
16+
selective_dynamics_site_prop_to_jdftx_interpretable,
17+
)
1318
from pymatgen.io.jdftx.jdftxinfile_default_inputs import antoinePvap, default_inputs
1419
from pymatgen.io.jdftx.jdftxinfile_master_format import get_tag_object
1520

@@ -615,3 +620,21 @@ def test_jdftxinfile_comparison():
615620

616621
def test_antoine_pvap():
617622
assert_same_value(antoinePvap(298, 7.31549, 1794.88, -34.764), 1.06736e-10)
623+
624+
625+
def test_nan_stripping():
626+
jif = JDFTXInfile.from_file(ex_infile1_fname)
627+
# Single optional subtag being nan should only remove that subtag
628+
jif["elec-cutoff"]["EcutRho"] = np.nan
629+
jif = clean_infile_of_nans(jif)
630+
assert "EcutRho" not in jif["elec-cutoff"]
631+
# Single required subtag being nan should remove the whole tag
632+
jif["latt-move-scale"]["s0"] = np.nan
633+
jif = clean_infile_of_nans(jif)
634+
assert "latt-move-scale" not in jif
635+
jif.pop("fluid-solvent")
636+
jif.read_line("fluid-solvent H2O 55.338 ScalarEOS epsBulk nan")
637+
assert "epsBulk" in jif["fluid-solvent"][0]
638+
jif = clean_infile_of_nans(jif)
639+
assert "epsBulk" not in jif["fluid-solvent"][0]
640+
print(jif)

0 commit comments

Comments
 (0)