Skip to content

Commit 24cd680

Browse files
authored
JDFTx Inputs - boundary value checking (#4410)
* Boundary checking * Remove commented out code * Remove commented out code * typo * mypy fix?
1 parent 35dd50b commit 24cd680

File tree

5 files changed

+198
-20
lines changed

5 files changed

+198
-20
lines changed

src/pymatgen/io/jdftx/generic_tags.py

Lines changed: 111 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ def _validate_repeat(self, tag: str, value: Any) -> None:
112112
if not isinstance(value, list):
113113
raise TypeError(f"The '{tag}' tag can repeat but is not a list: '{value}'")
114114

115+
def validate_value_bounds(self, tag: str, value: Any) -> tuple[bool, str]:
    """Check a value against this tag's bounds (default: nothing to check).

    This base implementation performs no check at all and reports every
    value as valid.

    Args:
        tag (str): Name of the tag the value belongs to (unused here).
        value (Any): The value to check (unused here).

    Returns:
        tuple[bool, str]: Always ``(True, "")`` - a validity flag and an
            empty error message.
    """
    no_error = ""
    return (True, no_error)
121+
115122
@abstractmethod
116123
def read(self, tag: str, value_str: str) -> Any:
117124
"""Read and parse the value string for this tag.
@@ -365,7 +372,62 @@ def get_token_len(self) -> int:
365372

366373

367374
@dataclass
368-
class IntTag(AbstractTag):
375+
class AbstractNumericTag(AbstractTag):
    """Abstract base class for numeric tags.

    Adds optional lower/upper bounds to a tag, e.g.
    ``FloatTag(lb=0.0, lb_incl=False)`` for a value that must be strictly
    positive. A bound left as ``None`` is not enforced.
    """

    lb: float | None = None  # lower bound
    ub: float | None = None  # upper bound
    lb_incl: bool = True  # lower bound inclusive
    ub_incl: bool = True  # upper bound inclusive

    def val_is_within_bounds(self, value: float) -> bool:
        """Check if the value is within the bounds.

        Args:
            value (float | int): The value to check.

        Returns:
            bool: True if the value is within the bounds, False otherwise.
        """
        if self.lb is not None:
            above_lb = value >= self.lb if self.lb_incl else value > self.lb
            if not above_lb:
                return False
        if self.ub is not None:
            below_ub = value <= self.ub if self.ub_incl else value < self.ub
            if not below_ub:
                return False
        return True

    def get_invalid_value_error_str(self, tag: str, value: float) -> str:
        """Build the error message for an out-of-bounds value.

        Nothing is raised here; the message is returned so the caller can
        collect messages from several tags before reporting them.

        Args:
            tag (str): The tag the invalid value belongs to.
            value (float | int): The invalid value.

        Returns:
            str: A message of the form
                ``Value '<value>' for tag '<tag>' is not within bounds <ub> >= x > <lb>``,
                where only the bounds that are set are shown.
        """
        bounds = "x"
        if self.ub is not None:
            bounds = f"{self.ub} {'>=' if self.ub_incl else '>'} {bounds}"
        if self.lb is not None:
            bounds = f"{bounds} {'>=' if self.lb_incl else '>'} {self.lb}"
        return f"Value '{value}' for tag '{tag}' is not within bounds {bounds}"

    def validate_value_bounds(
        self,
        tag: str,
        value: Any,
    ) -> tuple[bool, str]:
        """Validate a value (or a list of values) against the bounds.

        Args:
            tag (str): The tag the value belongs to.
            value (Any): A single numeric value, or a list of numeric values
                (each element is then checked individually).

        Returns:
            tuple[bool, str]: ``(True, "")`` if everything is within bounds,
                otherwise ``(False, message)`` describing the first value
                found out of bounds.
        """
        # A list cannot be compared against a float bound, so check it
        # element by element instead of failing with a TypeError.
        candidates = value if isinstance(value, list) else [value]
        for candidate in candidates:
            if not self.val_is_within_bounds(candidate):
                return False, self.get_invalid_value_error_str(tag, candidate)
        return True, ""
427+
428+
429+
@dataclass
430+
class IntTag(AbstractNumericTag):
369431
"""Tag for integer values in JDFTx input files.
370432
371433
Tag for integer values in JDFTx input files.
@@ -411,6 +473,8 @@ def write(self, tag: str, value: Any) -> str:
411473
Returns:
412474
str: The tag and its value as a string.
413475
"""
476+
if not self.val_is_within_bounds(value):
477+
return ""
414478
return self._write(tag, value)
415479

416480
def get_token_len(self) -> int:
@@ -423,14 +487,13 @@ def get_token_len(self) -> int:
423487

424488

425489
@dataclass
426-
class FloatTag(AbstractTag):
490+
class FloatTag(AbstractNumericTag):
427491
"""Tag for float values in JDFTx input files.
428492
429493
Tag for float values in JDFTx input files.
430494
"""
431495

432496
prec: int | None = None
433-
minval: float | None = None
434497

435498
def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = False) -> tuple[str, bool, Any]:
436499
"""Validate the type of the value for this tag.
@@ -473,10 +536,7 @@ def write(self, tag: str, value: Any) -> str:
473536
Returns:
474537
str: The tag and its value as a string.
475538
"""
476-
# Returning an empty string instead of raising an error as value == self.minval
477-
# will cause JDFTx to throw an error, but the internal infile dumps the value as
478-
# as the minval if not set by the user.
479-
if (self.minval is not None) and (not value > self.minval):
539+
if not self.val_is_within_bounds(value):
480540
return ""
481541
# pre-convert to string: self.prec+3 is minimum room for:
482542
# - sign, 1 integer left of decimal, decimal, and precision.
@@ -598,6 +658,50 @@ def _validate_single_entry(
598658
types_checks.append(check)
599659
return tags_checked, types_checks, updated_value
600660

661+
def _validate_bounds_single_entry(self, value: dict | list[dict]) -> tuple[list[str], list[bool], list[str]]:
662+
if not isinstance(value, dict):
663+
raise TypeError(f"The value '{value}' (of type {type(value)}) must be a dict for this TagContainer!")
664+
tags_checked: list[str] = []
665+
types_checks: list[bool] = []
666+
reported_errors: list[str] = []
667+
for subtag, subtag_value in value.items():
668+
subtag_object = self.subtags[subtag]
669+
check, err_str = subtag_object.validate_value_bounds(subtag, subtag_value)
670+
tags_checked.append(subtag)
671+
types_checks.append(check)
672+
reported_errors.append(err_str)
673+
return tags_checked, types_checks, reported_errors
674+
675+
def validate_value_bounds(self, tag: str, value: Any) -> tuple[bool, str]:
    """Validate the bounds of every subtag value held by this container.

    Args:
        tag (str): Name of this container tag; used as a prefix of the
            returned error message.
        value (Any): A dict of subtag values, or a list of such dicts when
            this tag can repeat.

    Returns:
        tuple[bool, str]: ``(True, "")`` when every subtag value is within
            bounds, otherwise ``(False, "<tag>: <err>,<err>,...")`` listing
            only the messages of the failing subtags.

    Warns:
        A warning naming each failing subtag and its message is emitted
        whenever at least one value is out of bounds.
    """
    if self.can_repeat:
        self._validate_repeat(tag, value)
        entries = value
    else:
        entries = [value]
    # Check every entry before reporting anything, so a malformed entry
    # raises before any warning is emitted.
    results = [self._validate_bounds_single_entry(entry) for entry in entries]
    failures: list[tuple[str, str]] = [
        (subtag, error)
        for subtags, is_valids, errors in results
        for subtag, is_valid, error in zip(subtags, is_valids, errors)
        if not is_valid
    ]
    if not failures:
        # Same "valid" result as a plain tag, so a parent container never
        # picks up an empty "<tag>: " fragment from a valid child.
        return True, ""
    warnmsg = "Invalid value(s) found for: " + "".join(f"{subtag} ({error}) " for subtag, error in failures)
    warnings.warn(warnmsg, stacklevel=2)
    # Join only the failing messages; valid subtags report "" and would
    # otherwise leave stray commas in the output.
    errors_out = ",".join(error for _, error in failures)
    return False, f"{tag}: {errors_out}"
704+
601705
def validate_value_type(self, tag: str, value: Any, try_auto_type_fix: bool = False) -> tuple[str, bool, Any]:
602706
"""Validate the type of the value for this tag.
603707

src/pymatgen/io/jdftx/inputs.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class JDFTXInfile(dict, MSONable):
5959
Essentially a dictionary with some helper functions.
6060
"""
6161

62-
path_parent: str | None = None # Only gets a value if JDFTXInfile is initializedf with from_file
62+
path_parent: str | None = None # Only gets a value if JDFTXInfile is initialized with from_file
6363

6464
def __init__(self, params: dict[str, Any] | None = None) -> None:
6565
"""
@@ -147,7 +147,7 @@ def _from_dict(cls, dct: dict[str, Any]) -> JDFTXInfile:
147147
return cls.get_list_representation(temp)
148148

149149
@classmethod
150-
def from_dict(cls, d: dict[str, Any]) -> JDFTXInfile:
150+
def from_dict(cls, d: dict[str, Any], validate_value_boundaries=True) -> JDFTXInfile:
151151
"""Create JDFTXInfile from a dictionary.
152152
153153
Args:
@@ -160,6 +160,8 @@ def from_dict(cls, d: dict[str, Any]) -> JDFTXInfile:
160160
for k, v in d.items():
161161
if k not in ("@module", "@class"):
162162
instance[k] = v
163+
if validate_value_boundaries:
164+
instance.validate_boundaries()
163165
return instance
164166

165167
def copy(self) -> JDFTXInfile:
@@ -213,6 +215,7 @@ def from_file(
213215
dont_require_structure: bool = False,
214216
sort_tags: bool = True,
215217
assign_path_parent: bool = True,
218+
validate_value_boundaries: bool = True,
216219
) -> JDFTXInfile:
217220
"""Read a JDFTXInfile object from a file.
218221
@@ -235,6 +238,7 @@ def from_file(
235238
dont_require_structure=dont_require_structure,
236239
sort_tags=sort_tags,
237240
path_parent=path_parent,
241+
validate_value_boundaries=validate_value_boundaries,
238242
)
239243

240244
@staticmethod
@@ -373,6 +377,7 @@ def from_str(
373377
dont_require_structure: bool = False,
374378
sort_tags: bool = True,
375379
path_parent: Path | None = None,
380+
validate_value_boundaries: bool = True,
376381
) -> JDFTXInfile:
377382
"""Read a JDFTXInfile object from a string.
378383
@@ -382,6 +387,7 @@ def from_str(
382387
sort_tags (bool, optional): Whether to sort the tags. Defaults to True.
383388
path_parent (Path, optional): Path to the parent directory of the input file for include tags.
384389
Defaults to None.
390+
validate_value_boundaries (bool, optional): Whether to validate the value boundaries. Defaults to True.
385391
386392
Returns:
387393
JDFTXInfile: The created JDFTXInfile object.
@@ -416,7 +422,10 @@ def from_str(
416422
raise ValueError("This input file is missing required structure tags")
417423
if sort_tags:
418424
params = {tag: params[tag] for tag in __TAG_LIST__ if tag in params}
419-
return cls(params)
425+
instance = cls(params)
426+
if validate_value_boundaries:
427+
instance.validate_boundaries()
428+
return instance
420429

421430
@classmethod
422431
def to_jdftxstructure(cls, jdftxinfile: JDFTXInfile, sort_structure: bool = False) -> JDFTXStructure:
@@ -573,6 +582,25 @@ def validate_tags(
573582
warnmsg += "(Check earlier warnings for more details)\n"
574583
warnings.warn(warnmsg, stacklevel=2)
575584

585+
def validate_boundaries(self) -> None:
    """Validate the boundaries of the JDFTXInfile.

    Looks up the tag object of every tag stored in this JDFTXInfile, asks it
    to check the stored value against its bounds, and reports all failures
    together.

    Raises:
        ValueError: If one or more tags report a value outside their bounds.
            The message lists every reported error, one per line.
    """
    error_strs: list[str] = []
    for tag in self:
        is_valid, error_str = get_tag_object(tag).validate_value_bounds(tag, self[tag])
        if not is_valid:
            error_strs.append(error_str)
    if error_strs:
        err_cat = "\n".join(error_strs)
        raise ValueError(
            f"The following boundary errors were found in the JDFTXInfile:\n{err_cat}\n"
            "\n Hint - if you are reading from a JDFTX out file, you need to set validate_value_boundaries "
            "to False, as JDFTx will dump values at non-inclusive boundaries (ie 0.0 for values strictly > 0.0)."
        )
603+
576604
def strip_structure_tags(self) -> None:
577605
"""Strip all structural tags from the JDFTXInfile.
578606
@@ -614,7 +642,7 @@ def __setitem__(self, key: str, value: Any) -> None:
614642
if self._is_numeric(value):
615643
value = str(value)
616644
if not tag_object.can_repeat:
617-
value = [value]
645+
value = [value] # Shortcut to avoid writing a separate block for non-repeatable tags
618646
for v in value:
619647
processed_value = tag_object.read(key, v) if isinstance(v, str) else v
620648
params = self._store_value(params, tag_object, key, processed_value)

src/pymatgen/io/jdftx/jdftxinfile_ref_options.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -800,11 +800,11 @@
800800
"wolfeGradient": FloatTag(),
801801
}
802802
jdftxfluid_subtagdict = {
803-
"epsBulk": FloatTag(minval=1.0),
804-
"epsInf": FloatTag(),
803+
"epsBulk": FloatTag(lb=0.0, lb_incl=False),
804+
"epsInf": FloatTag(lb=1.0, lb_incl=True),
805805
"epsLJ": FloatTag(),
806806
"Nnorm": FloatTag(),
807-
"pMol": FloatTag(),
807+
"pMol": FloatTag(lb=0.0, lb_incl=True),
808808
"poleEl": TagContainer(
809809
can_repeat=True,
810810
write_tagname=True,
@@ -814,13 +814,13 @@
814814
"A0": FloatTag(write_tagname=False, optional=False),
815815
},
816816
),
817-
"Pvap": FloatTag(minval=0.0),
817+
"Pvap": FloatTag(lb=0.0, lb_incl=False),
818818
"quad_nAlpha": FloatTag(),
819819
"quad_nBeta": FloatTag(),
820820
"quad_nGamma": FloatTag(),
821821
"representation": TagContainer(subtags={"MuEps": FloatTag(), "Pomega": FloatTag(), "PsiAlpha": FloatTag()}),
822-
"Res": FloatTag(minval=0.0),
823-
"Rvdw": FloatTag(),
822+
"Res": FloatTag(lb=0.0, lb_incl=False),
823+
"Rvdw": FloatTag(lb=0.0, lb_incl=False),
824824
"s2quadType": StrTag(
825825
options=[
826826
"10design60",
@@ -844,7 +844,7 @@
844844
"Tetrahedron",
845845
]
846846
),
847-
"sigmaBulk": FloatTag(minval=0.0),
848-
"tauNuc": FloatTag(),
847+
"sigmaBulk": FloatTag(lb=0.0, lb_incl=False),
848+
"tauNuc": FloatTag(lb=0.0, lb_incl=False),
849849
"translation": StrTag(options=["ConstantSpline", "Fourier", "LinearSpline"]),
850850
}

src/pymatgen/io/jdftx/jdftxoutfileslice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,9 @@ def _set_internal_infile(self, text: list[str]) -> None:
402402
break
403403
if end_line_idx is None:
404404
raise ValueError("Calculation did not begin for this out file slice.")
405-
self.infile = JDFTXInfile.from_str("\n".join(text[start_line_idx:end_line_idx]))
405+
self.infile = JDFTXInfile.from_str(
406+
"\n".join(text[start_line_idx:end_line_idx]), validate_value_boundaries=False
407+
)
406408
self.constant_lattice = True
407409
if "lattice-minimize" in self.infile:
408410
latsteps = self.infile["lattice-minimize"]["nIterations"]

tests/io/jdftx/test_generic_tags.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,47 @@ def test_multiformattagcontainer():
441441
"Check your inputs and/or MASTER_TAG_LIST!"
442442
with pytest.raises(ValueError, match=re.escape(err_str)):
443443
mftg._determine_format_option(tag, value)
444+
445+
446+
def test_boundary_checking():
    """Exercise bounds validation on plain tags, numeric tags and containers."""
    tag = "barbie"

    # A non-numeric tag has no bounds, so any value is reported as valid.
    assert StrTag().validate_value_bounds(tag, "notanumber")[0] is True

    value = 0.0
    # Numeric tags whose bounds exclude 0.0 must reject it.
    rejecting_tags = [
        FloatTag(lb=1.0),
        FloatTag(lb=0.0, lb_incl=False),
        FloatTag(ub=-1.0),
        FloatTag(ub=0.0, ub_incl=False),
    ]
    for valtag in rejecting_tags:
        assert valtag.validate_value_bounds(tag, value)[0] is False

    # Numeric tags whose bounds include 0.0 must accept it.
    accepting_tags = [
        FloatTag(lb=0.0, lb_incl=True),
        FloatTag(ub=0.0, ub_incl=True),
        FloatTag(lb=-1.0, ub=1.0),
    ]
    for valtag in accepting_tags:
        assert valtag.validate_value_bounds(tag, value)[0] is True

    # A container reports the failures of its subtags and nothing else.
    tagcontainer = TagContainer(
        subtags={
            "ken": FloatTag(lb=0.0, lb_incl=True),
            "allan": StrTag(),
            "skipper": FloatTag(ub=1.0, ub_incl=True, lb=-1.0, lb_incl=False),
        },
    )
    container_value = {"ken": -1.0, "allan": "notanumber", "skipper": 2.0}
    valid, errors = tagcontainer.validate_value_bounds(tag, container_value)
    assert valid is False
    assert "allan" not in errors
    for expected_fragment in ("ken", "x >= 0.0", "skipper", "1.0 >= x > -1.0"):
        assert expected_fragment in errors

    # Tags must never write a value that is out of bounds.
    bounded_tag = FloatTag(lb=-1.0, ub=1.0)
    assert len(bounded_tag.write(tag, 0.0)) > 0
    assert len(bounded_tag.write(tag, 2.0)) == 0

0 commit comments

Comments
 (0)