@@ -112,6 +112,13 @@ def _validate_repeat(self, tag: str, value: Any) -> None:
112
112
if not isinstance (value , list ):
113
113
raise TypeError (f"The '{ tag } ' tag can repeat but is not a list: '{ value } '" )
114
114
115
+ def validate_value_bounds (
116
+ self ,
117
+ tag : str ,
118
+ value : Any ,
119
+ ) -> tuple [bool , str ]:
120
+ return True , ""
121
+
115
122
@abstractmethod
116
123
def read (self , tag : str , value_str : str ) -> Any :
117
124
"""Read and parse the value string for this tag.
@@ -365,7 +372,62 @@ def get_token_len(self) -> int:
365
372
366
373
367
374
@dataclass
368
- class IntTag (AbstractTag ):
375
+ class AbstractNumericTag (AbstractTag ):
376
+ """Abstract base class for numeric tags."""
377
+
378
+ lb : float | None = None # lower bound
379
+ ub : float | None = None # upper bound
380
+ lb_incl : bool = True # lower bound inclusive
381
+ ub_incl : bool = True # upper bound inclusive
382
+
383
+ def val_is_within_bounds (self , value : float ) -> bool :
384
+ """Check if the value is within the bounds.
385
+
386
+ Args:
387
+ value (float | int): The value to check.
388
+
389
+ Returns:
390
+ bool: True if the value is within the bounds, False otherwise.
391
+ """
392
+ good = True
393
+ if self .lb is not None :
394
+ good = good and value >= self .lb if self .lb_incl else good and value > self .lb
395
+ if self .ub is not None :
396
+ good = good and value <= self .ub if self .ub_incl else good and value < self .ub
397
+ return good
398
+
399
+ def get_invalid_value_error_str (self , tag : str , value : float ) -> str :
400
+ """Raise a ValueError for the invalid value.
401
+
402
+ Args:
403
+ tag (str): The tag to raise the ValueError for.
404
+ value (float | int): The value to raise the ValueError for.
405
+ """
406
+ err_str = f"Value '{ value } ' for tag '{ tag } ' is not within bounds"
407
+ if self .ub is not None :
408
+ err_str += f" { self .ub } >"
409
+ if self .ub_incl :
410
+ err_str += "="
411
+ err_str += " x "
412
+ if self .lb is not None :
413
+ err_str += ">"
414
+ if self .lb_incl :
415
+ err_str += "="
416
+ err_str += f" { self .lb } "
417
+ return err_str
418
+
419
+ def validate_value_bounds (
420
+ self ,
421
+ tag : str ,
422
+ value : Any ,
423
+ ) -> tuple [bool , str ]:
424
+ if not self .val_is_within_bounds (value ):
425
+ return False , self .get_invalid_value_error_str (tag , value )
426
+ return True , ""
427
+
428
+
429
+ @dataclass
430
+ class IntTag (AbstractNumericTag ):
369
431
"""Tag for integer values in JDFTx input files.
370
432
371
433
Tag for integer values in JDFTx input files.
@@ -411,6 +473,8 @@ def write(self, tag: str, value: Any) -> str:
411
473
Returns:
412
474
str: The tag and its value as a string.
413
475
"""
476
+ if not self .val_is_within_bounds (value ):
477
+ return ""
414
478
return self ._write (tag , value )
415
479
416
480
def get_token_len (self ) -> int :
@@ -423,14 +487,13 @@ def get_token_len(self) -> int:
423
487
424
488
425
489
@dataclass
426
- class FloatTag (AbstractTag ):
490
+ class FloatTag (AbstractNumericTag ):
427
491
"""Tag for float values in JDFTx input files.
428
492
429
493
Tag for float values in JDFTx input files.
430
494
"""
431
495
432
496
prec : int | None = None
433
- minval : float | None = None
434
497
435
498
def validate_value_type (self , tag : str , value : Any , try_auto_type_fix : bool = False ) -> tuple [str , bool , Any ]:
436
499
"""Validate the type of the value for this tag.
@@ -473,10 +536,7 @@ def write(self, tag: str, value: Any) -> str:
473
536
Returns:
474
537
str: The tag and its value as a string.
475
538
"""
476
- # Returning an empty string instead of raising an error as value == self.minval
477
- # will cause JDFTx to throw an error, but the internal infile dumps the value as
478
- # as the minval if not set by the user.
479
- if (self .minval is not None ) and (not value > self .minval ):
539
+ if not self .val_is_within_bounds (value ):
480
540
return ""
481
541
# pre-convert to string: self.prec+3 is minimum room for:
482
542
# - sign, 1 integer left of decimal, decimal, and precision.
@@ -598,6 +658,50 @@ def _validate_single_entry(
598
658
types_checks .append (check )
599
659
return tags_checked , types_checks , updated_value
600
660
661
+ def _validate_bounds_single_entry (self , value : dict | list [dict ]) -> tuple [list [str ], list [bool ], list [str ]]:
662
+ if not isinstance (value , dict ):
663
+ raise TypeError (f"The value '{ value } ' (of type { type (value )} ) must be a dict for this TagContainer!" )
664
+ tags_checked : list [str ] = []
665
+ types_checks : list [bool ] = []
666
+ reported_errors : list [str ] = []
667
+ for subtag , subtag_value in value .items ():
668
+ subtag_object = self .subtags [subtag ]
669
+ check , err_str = subtag_object .validate_value_bounds (subtag , subtag_value )
670
+ tags_checked .append (subtag )
671
+ types_checks .append (check )
672
+ reported_errors .append (err_str )
673
+ return tags_checked , types_checks , reported_errors
674
+
675
+ def validate_value_bounds (self , tag : str , value : Any ) -> tuple [bool , str ]:
676
+ value_dict = value
677
+ if self .can_repeat :
678
+ self ._validate_repeat (tag , value_dict )
679
+ results = [self ._validate_bounds_single_entry (x ) for x in value_dict ]
680
+ tags_list_list : list [list [str ]] = [result [0 ] for result in results ]
681
+ is_valids_list_list : list [list [bool ]] = [result [1 ] for result in results ]
682
+ reported_errors_list : list [list [str ]] = [result [2 ] for result in results ]
683
+ is_valid_out = all (all (x ) for x in is_valids_list_list )
684
+ errors_out = "," .join (["," .join (x ) for x in reported_errors_list ])
685
+ if not is_valid_out :
686
+ warnmsg = "Invalid value(s) found for: "
687
+ for i , x in enumerate (is_valids_list_list ):
688
+ if not all (x ):
689
+ for j , y in enumerate (x ):
690
+ if not y :
691
+ warnmsg += f"{ tags_list_list [i ][j ]} ({ reported_errors_list [i ][j ]} ) "
692
+ warnings .warn (warnmsg , stacklevel = 2 )
693
+ else :
694
+ tags , is_valids , reported_errors = self ._validate_bounds_single_entry (value_dict )
695
+ is_valid_out = all (is_valids )
696
+ errors_out = "," .join (reported_errors )
697
+ if not is_valid_out :
698
+ warnmsg = "Invalid value(s) found for: "
699
+ for ii , xx in enumerate (is_valids ):
700
+ if not xx :
701
+ warnmsg += f"{ tags [ii ]} ({ reported_errors [ii ]} ) "
702
+ warnings .warn (warnmsg , stacklevel = 2 )
703
+ return is_valid_out , f"{ tag } : { errors_out } "
704
+
601
705
def validate_value_type (self , tag : str , value : Any , try_auto_type_fix : bool = False ) -> tuple [str , bool , Any ]:
602
706
"""Validate the type of the value for this tag.
603
707
0 commit comments