33from __future__ import annotations
44import copy
55from dataclasses import dataclass
6- from math import isclose
76import numpy as np
87from emmet .core .vasp .calc_types .enums import TaskType
98
10- from pymatgen .io .validation .common import BaseValidator
9+ from pymatgen .io .validation .common import BaseValidator , BasicValidator
10+ from pymatgen .io .validation .vasp_defaults import InputCategory
1111
12- from typing import TYPE_CHECKING , Literal
12+ from typing import TYPE_CHECKING
1313
1414if TYPE_CHECKING :
1515 from typing import Any , Sequence
@@ -77,9 +77,9 @@ def check(self) -> None:
7777 Check calculation parameters related to INCAR input tags.
7878
7979 This first updates any parameter with a specified update method.
80- In practice, each INCAR tag in `vasp_defaults.yaml ` has a "tag"
81- attribute. If there is an update method
82- `UpdateParameterValues.update_{tag.replace(" ","_") }_params`,
80+ In practice, each INCAR tag in `VASP ` has a "tag" attribute.
81+ If there is an update method
82+ `UpdateParameterValues.update_{tag}_params`,
8383 all parameters with that tag will be updated.
8484
8585 Then after all missing values in the supplied parameters (padding
@@ -116,7 +116,7 @@ def check(self) -> None:
116116 simple_validator .check_parameter (
117117 reasons = self .reasons ,
118118 warnings = self .warnings ,
119- input_tag = working_params .defaults [key ].get ("alias" , key ) ,
119+ input_tag = working_params .defaults [key ].get ("alias" ) or key ,
120120 current_values = working_params .parameters [key ],
121121 reference_values = working_params .valid_values [key ],
122122 operations = working_params .defaults [key ]["operation" ],
@@ -125,6 +125,8 @@ def check(self) -> None:
125125 severity = working_params .defaults [key ]["severity" ],
126126 )
127127
128+ if key == "LCHIMAG" :
129+ print (self .reasons )
128130
129131class UpdateParameterValues :
130132 """
@@ -141,9 +143,9 @@ class UpdateParameterValues:
141143 This class allows one to mimic the VASP NBANDS functionality for computing
142144 NBANDS dynamically, and update both the current and reference values for NBANDs.
143145
144- To do this in a simple, automatic fashion, each parameter in `vasp_defaults.yaml ` has
146+ To do this in a simple, automatic fashion, each parameter in `VASP_DEFAULTS ` has
145147 a "tag" field. To update a set of parameters with a given tag, one then adds a function
146- to `GetParams` called `update_{tag.replace(" ","_") }_params`. For example, the "dft plus u"
148+ to `GetParams` called `update_{tag}_params`. For example, the "dft plus u"
147149 tag has an update function called `update_dft_plus_u_params`. If no such update method
148150 exists, that tag is skipped.
149151
@@ -201,7 +203,7 @@ def __init__(
201203 """
202204
203205 self .parameters = copy .deepcopy (parameters )
204- self .defaults = copy . deepcopy ( defaults )
206+ self .defaults = { k : v . __dict__ () for k , v in defaults . items ()}
205207 self .input_set = input_set
206208 self .vasp_version = vasp_version
207209 self .structure = structure
@@ -220,17 +222,14 @@ def __init__(
220222 def update_parameters_and_defaults (self ) -> None :
221223 """Update user parameters and defaults for tags with a specified update method."""
222224
223- self .categories : dict [str , list [str ]] = {}
225+ self .categories : dict [str , list [str ]] = {tag : [] for tag in InputCategory . __members__ }
224226 for key in self .defaults :
225- if self .defaults [key ]["tag" ] not in self .categories :
226- self .categories [self .defaults [key ]["tag" ]] = []
227227 self .categories [self .defaults [key ]["tag" ]].append (key )
228228
229- tag_order = [key .replace (" " , "_" ) for key in self .categories if key != "post_init" ] + ["post_init" ]
230229 # add defaults to parameters from the incar as needed
231230 self .add_defaults_to_parameters (valid_values_source = self .input_set .incar )
232231 # collect list of tags in parameter defaults
233- for tag in tag_order :
232+ for tag in InputCategory . __members__ :
234233 # check to see if update method for that tag exists, and if so, run it
235234 update_method_str = f"update_{ tag } _params"
236235 if hasattr (self , update_method_str ):
@@ -264,7 +263,7 @@ def update_dft_plus_u_params(self) -> None:
264263 if not self .parameters ["LDAU" ]:
265264 return
266265
267- for key in self .categories ["dft plus u " ]:
266+ for key in self .categories ["dft_plus_u " ]:
268267 valid_value = self .input_set .incar .get (key , self .defaults [key ]["value" ])
269268
270269 # TODO: ADK: is LDAUTYPE usually specified as a list??
@@ -785,198 +784,3 @@ def update_post_init_params(self):
785784 4.0 * self .defaults ["NBANDS" ]["value" ]
786785 )
787786 self .defaults ["EBREAK" ]["operation" ] = "auto fail"
788-
789-
790- class BasicValidator :
791- """
792- Compare test and reference values according to one or more operations.
793-
794- Parameters
795- -----------
796- global_tolerance : float = 1.e-4
797- Default tolerance for assessing approximate equality via math.isclose
798-
799- Attrs
800- -----------
801- operations : set[str]
802- List of acceptable operations, such as "==" for strict equality, or "in" to
803- check if a Sequence contains an element
804- """
805-
806- # avoiding dunder methods because these raise too many NotImplemented's
807- operations : set [str | None ] = {
808- "==" ,
809- ">" ,
810- ">=" ,
811- "<" ,
812- "<=" ,
813- "in" ,
814- "approx" ,
815- "auto fail" ,
816- None ,
817- }
818-
819- def __init__ (self , global_tolerance = 1.0e-4 ) -> None :
820- """Set math.isclose tolerance"""
821- self .tolerance = global_tolerance
822-
823- def _comparator (self , lhs : Any , operation : str , rhs : Any , ** kwargs ) -> bool :
824- """
825- Compare different values using one of a set of supported operations in self.operations.
826-
827- Parameters
828- -----------
829- lhs : Any
830- Left-hand side of the operation.
831- operation : str
832- Operation acting on rhs from lhs. For example, if operation is ">",
833- this returns (lhs > rhs).
834- rhs : Any
835- Right-hand of the operation.
836- kwargs
837- If needed, kwargs to pass to operation.
838- """
839- if operation is None :
840- c = True
841- elif operation == "auto fail" :
842- c = False
843- elif operation == "==" :
844- c = lhs == rhs
845- elif operation == ">" :
846- c = lhs > rhs
847- elif operation == ">=" :
848- c = lhs >= rhs
849- elif operation == "<" :
850- c = lhs < rhs
851- elif operation == "<=" :
852- c = lhs <= rhs
853- elif operation == "in" :
854- c = lhs in rhs
855- elif operation == "approx" :
856- c = isclose (lhs , rhs , ** kwargs )
857- return c
858-
859- def _check_parameter (
860- self ,
861- error_list : list [str ],
862- input_tag : str ,
863- current_value : Any ,
864- reference_value : Any ,
865- operation : str ,
866- tolerance : float | None = None ,
867- append_comments : str | None = None ,
868- ) -> None :
869- """
870- Determine validity of parameter subject to a single specified operation.
871-
872- Parameters
873- -----------
874- error_list : list[str]
875- A list of error/warning strings to update if a check fails.
876- input_tag : str
877- The name of the input tag which is being checked.
878- current_value : Any
879- The test value.
880- reference_value : Any
881- The value to compare the test value to.
882- operation : str
883- A valid operation in self.operations. For example, if operation = "<=",
884- this checks `current_value <= reference_value` (note order of values).
885- tolerance : float or None (default)
886- If None and operation == "approx", default tolerance to self.tolerance.
887- Otherwise, use the user-supplied tolerance.
888- append_comments : str or None (default)
889- Additional comments that may be helpful for the user to understand why
890- a check failed.
891- """
892-
893- append_comments = append_comments or ""
894-
895- if isinstance (current_value , str ):
896- current_value = current_value .upper ()
897-
898- kwargs : dict [str , Any ] = {}
899- if operation == "approx" and isinstance (current_value , float ):
900- kwargs .update ({"rel_tol" : tolerance or self .tolerance , "abs_tol" : 0.0 })
901- valid_value = self ._comparator (current_value , operation , reference_value , ** kwargs )
902-
903- if not valid_value :
904- error_list .append (
905- f"INPUT SETTINGS --> { input_tag } : is { current_value } , but should be "
906- f"{ '' if operation == 'auto fail' else operation + ' ' } { reference_value } ."
907- f"{ ' ' if len (append_comments ) > 0 else '' } { append_comments } "
908- )
909-
910- def check_parameter (
911- self ,
912- reasons : list [str ],
913- warnings : list [str ],
914- input_tag : str ,
915- current_values : Any ,
916- reference_values : Any ,
917- operations : str | list [str ],
918- tolerance : float = None ,
919- append_comments : str | None = None ,
920- severity : Literal ["reason" , "warning" ] = "reason" ,
921- ) -> None :
922- """
923- Determine validity of parameter according to one or more operations.
924-
925- Parameters
926- -----------
927- reasons : list[str]
928- A list of error strings to update if a check fails. These are higher
929- severity and would deprecate a calculation.
930- warnings : list[str]
931- A list of warning strings to update if a check fails. These are lower
932- severity and would flag a calculation for possible review.
933- input_tag : str
934- The name of the input tag which is being checked.
935- current_values : Any
936- The test value(s). If multiple operations are specified, must be a Sequence
937- of test values.
938- reference_values : Any
939- The value(s) to compare the test value(s) to. If multiple operations are
940- specified, must be a Sequence of reference values.
941- operations : str
942- One or more valid operations in self.operations. For example, if operations = "<=",
943- this checks `current_values <= reference_values` (note order of values).
944- Or, if operations == ["<=", ">"], this checks
945- ```
946- (
947- (current_values[0] <= reference_values[0])
948- and (current_values[1] > reference_values[1])
949- )
950- ```
951- tolerance : float or None (default)
952- Tolerance to use in math.isclose if any of operations is "approx". Defaults
953- to self.tolerance.
954- append_comments : str or None (default)
955- Additional comments that may be helpful for the user to understand why
956- a check failed.
957- severity : Literal["reason", "warning"]
958- If a calculation fails, the severity of failure. Directs output to
959- either reasons or warnings.
960- """
961-
962- severity_to_list = {"reason" : reasons , "warning" : warnings }
963-
964- if not isinstance (operations , list ):
965- operations = [operations ]
966- current_values = [current_values ]
967- reference_values = [reference_values ]
968-
969- unknown_operations = {operation for operation in operations if operation not in self .operations }
970- if len (unknown_operations ) > 0 :
971- raise ValueError ("Unknown operations:\n " + ", " .join ([f"{ uo } " for uo in unknown_operations ]))
972-
973- for iop in range (len (operations )):
974- self ._check_parameter (
975- error_list = severity_to_list [severity ],
976- input_tag = input_tag ,
977- current_value = current_values [iop ],
978- reference_value = reference_values [iop ],
979- operation = operations [iop ],
980- tolerance = tolerance ,
981- append_comments = append_comments ,
982- )
0 commit comments