3
3
from __future__ import annotations
4
4
import copy
5
5
from dataclasses import dataclass
6
- from math import isclose
7
6
import numpy as np
8
7
from emmet .core .vasp .calc_types .enums import TaskType
9
8
10
- from pymatgen .io .validation .common import BaseValidator
9
+ from pymatgen .io .validation .common import BaseValidator , BasicValidator
10
+ from pymatgen .io .validation .vasp_defaults import InputCategory
11
11
12
- from typing import TYPE_CHECKING , Literal
12
+ from typing import TYPE_CHECKING
13
13
14
14
if TYPE_CHECKING :
15
15
from typing import Any , Sequence
@@ -77,9 +77,9 @@ def check(self) -> None:
77
77
Check calculation parameters related to INCAR input tags.
78
78
79
79
This first updates any parameter with a specified update method.
80
- In practice, each INCAR tag in `vasp_defaults.yaml ` has a "tag"
81
- attribute. If there is an update method
82
- `UpdateParameterValues.update_{tag.replace(" ","_") }_params`,
80
+ In practice, each INCAR tag in `VASP ` has a "tag" attribute.
81
+ If there is an update method
82
+ `UpdateParameterValues.update_{tag}_params`,
83
83
all parameters with that tag will be updated.
84
84
85
85
Then after all missing values in the supplied parameters (padding
@@ -116,7 +116,7 @@ def check(self) -> None:
116
116
simple_validator .check_parameter (
117
117
reasons = self .reasons ,
118
118
warnings = self .warnings ,
119
- input_tag = working_params .defaults [key ].get ("alias" , key ) ,
119
+ input_tag = working_params .defaults [key ].get ("alias" ) or key ,
120
120
current_values = working_params .parameters [key ],
121
121
reference_values = working_params .valid_values [key ],
122
122
operations = working_params .defaults [key ]["operation" ],
@@ -125,6 +125,8 @@ def check(self) -> None:
125
125
severity = working_params .defaults [key ]["severity" ],
126
126
)
127
127
128
+ if key == "LCHIMAG" :
129
+ print (self .reasons )
128
130
129
131
class UpdateParameterValues :
130
132
"""
@@ -141,9 +143,9 @@ class UpdateParameterValues:
141
143
This class allows one to mimic the VASP NBANDS functionality for computing
142
144
NBANDS dynamically, and update both the current and reference values for NBANDs.
143
145
144
- To do this in a simple, automatic fashion, each parameter in `vasp_defaults.yaml ` has
146
+ To do this in a simple, automatic fashion, each parameter in `VASP_DEFAULTS ` has
145
147
a "tag" field. To update a set of parameters with a given tag, one then adds a function
146
- to `GetParams` called `update_{tag.replace(" ","_") }_params`. For example, the "dft plus u"
148
+ to `GetParams` called `update_{tag}_params`. For example, the "dft plus u"
147
149
tag has an update function called `update_dft_plus_u_params`. If no such update method
148
150
exists, that tag is skipped.
149
151
@@ -201,7 +203,7 @@ def __init__(
201
203
"""
202
204
203
205
self .parameters = copy .deepcopy (parameters )
204
- self .defaults = copy . deepcopy ( defaults )
206
+ self .defaults = { k : v . __dict__ () for k , v in defaults . items ()}
205
207
self .input_set = input_set
206
208
self .vasp_version = vasp_version
207
209
self .structure = structure
@@ -220,17 +222,14 @@ def __init__(
220
222
def update_parameters_and_defaults (self ) -> None :
221
223
"""Update user parameters and defaults for tags with a specified update method."""
222
224
223
- self .categories : dict [str , list [str ]] = {}
225
+ self .categories : dict [str , list [str ]] = {tag : [] for tag in InputCategory . __members__ }
224
226
for key in self .defaults :
225
- if self .defaults [key ]["tag" ] not in self .categories :
226
- self .categories [self .defaults [key ]["tag" ]] = []
227
227
self .categories [self .defaults [key ]["tag" ]].append (key )
228
228
229
- tag_order = [key .replace (" " , "_" ) for key in self .categories if key != "post_init" ] + ["post_init" ]
230
229
# add defaults to parameters from the incar as needed
231
230
self .add_defaults_to_parameters (valid_values_source = self .input_set .incar )
232
231
# collect list of tags in parameter defaults
233
- for tag in tag_order :
232
+ for tag in InputCategory . __members__ :
234
233
# check to see if update method for that tag exists, and if so, run it
235
234
update_method_str = f"update_{ tag } _params"
236
235
if hasattr (self , update_method_str ):
@@ -264,7 +263,7 @@ def update_dft_plus_u_params(self) -> None:
264
263
if not self .parameters ["LDAU" ]:
265
264
return
266
265
267
- for key in self .categories ["dft plus u " ]:
266
+ for key in self .categories ["dft_plus_u " ]:
268
267
valid_value = self .input_set .incar .get (key , self .defaults [key ]["value" ])
269
268
270
269
# TODO: ADK: is LDAUTYPE usually specified as a list??
@@ -785,198 +784,3 @@ def update_post_init_params(self):
785
784
4.0 * self .defaults ["NBANDS" ]["value" ]
786
785
)
787
786
self .defaults ["EBREAK" ]["operation" ] = "auto fail"
788
-
789
-
790
- class BasicValidator :
791
- """
792
- Compare test and reference values according to one or more operations.
793
-
794
- Parameters
795
- -----------
796
- global_tolerance : float = 1.e-4
797
- Default tolerance for assessing approximate equality via math.isclose
798
-
799
- Attrs
800
- -----------
801
- operations : set[str]
802
- List of acceptable operations, such as "==" for strict equality, or "in" to
803
- check if a Sequence contains an element
804
- """
805
-
806
- # avoiding dunder methods because these raise too many NotImplemented's
807
- operations : set [str | None ] = {
808
- "==" ,
809
- ">" ,
810
- ">=" ,
811
- "<" ,
812
- "<=" ,
813
- "in" ,
814
- "approx" ,
815
- "auto fail" ,
816
- None ,
817
- }
818
-
819
- def __init__ (self , global_tolerance = 1.0e-4 ) -> None :
820
- """Set math.isclose tolerance"""
821
- self .tolerance = global_tolerance
822
-
823
- def _comparator (self , lhs : Any , operation : str , rhs : Any , ** kwargs ) -> bool :
824
- """
825
- Compare different values using one of a set of supported operations in self.operations.
826
-
827
- Parameters
828
- -----------
829
- lhs : Any
830
- Left-hand side of the operation.
831
- operation : str
832
- Operation acting on rhs from lhs. For example, if operation is ">",
833
- this returns (lhs > rhs).
834
- rhs : Any
835
- Right-hand of the operation.
836
- kwargs
837
- If needed, kwargs to pass to operation.
838
- """
839
- if operation is None :
840
- c = True
841
- elif operation == "auto fail" :
842
- c = False
843
- elif operation == "==" :
844
- c = lhs == rhs
845
- elif operation == ">" :
846
- c = lhs > rhs
847
- elif operation == ">=" :
848
- c = lhs >= rhs
849
- elif operation == "<" :
850
- c = lhs < rhs
851
- elif operation == "<=" :
852
- c = lhs <= rhs
853
- elif operation == "in" :
854
- c = lhs in rhs
855
- elif operation == "approx" :
856
- c = isclose (lhs , rhs , ** kwargs )
857
- return c
858
-
859
- def _check_parameter (
860
- self ,
861
- error_list : list [str ],
862
- input_tag : str ,
863
- current_value : Any ,
864
- reference_value : Any ,
865
- operation : str ,
866
- tolerance : float | None = None ,
867
- append_comments : str | None = None ,
868
- ) -> None :
869
- """
870
- Determine validity of parameter subject to a single specified operation.
871
-
872
- Parameters
873
- -----------
874
- error_list : list[str]
875
- A list of error/warning strings to update if a check fails.
876
- input_tag : str
877
- The name of the input tag which is being checked.
878
- current_value : Any
879
- The test value.
880
- reference_value : Any
881
- The value to compare the test value to.
882
- operation : str
883
- A valid operation in self.operations. For example, if operation = "<=",
884
- this checks `current_value <= reference_value` (note order of values).
885
- tolerance : float or None (default)
886
- If None and operation == "approx", default tolerance to self.tolerance.
887
- Otherwise, use the user-supplied tolerance.
888
- append_comments : str or None (default)
889
- Additional comments that may be helpful for the user to understand why
890
- a check failed.
891
- """
892
-
893
- append_comments = append_comments or ""
894
-
895
- if isinstance (current_value , str ):
896
- current_value = current_value .upper ()
897
-
898
- kwargs : dict [str , Any ] = {}
899
- if operation == "approx" and isinstance (current_value , float ):
900
- kwargs .update ({"rel_tol" : tolerance or self .tolerance , "abs_tol" : 0.0 })
901
- valid_value = self ._comparator (current_value , operation , reference_value , ** kwargs )
902
-
903
- if not valid_value :
904
- error_list .append (
905
- f"INPUT SETTINGS --> { input_tag } : is { current_value } , but should be "
906
- f"{ '' if operation == 'auto fail' else operation + ' ' } { reference_value } ."
907
- f"{ ' ' if len (append_comments ) > 0 else '' } { append_comments } "
908
- )
909
-
910
- def check_parameter (
911
- self ,
912
- reasons : list [str ],
913
- warnings : list [str ],
914
- input_tag : str ,
915
- current_values : Any ,
916
- reference_values : Any ,
917
- operations : str | list [str ],
918
- tolerance : float = None ,
919
- append_comments : str | None = None ,
920
- severity : Literal ["reason" , "warning" ] = "reason" ,
921
- ) -> None :
922
- """
923
- Determine validity of parameter according to one or more operations.
924
-
925
- Parameters
926
- -----------
927
- reasons : list[str]
928
- A list of error strings to update if a check fails. These are higher
929
- severity and would deprecate a calculation.
930
- warnings : list[str]
931
- A list of warning strings to update if a check fails. These are lower
932
- severity and would flag a calculation for possible review.
933
- input_tag : str
934
- The name of the input tag which is being checked.
935
- current_values : Any
936
- The test value(s). If multiple operations are specified, must be a Sequence
937
- of test values.
938
- reference_values : Any
939
- The value(s) to compare the test value(s) to. If multiple operations are
940
- specified, must be a Sequence of reference values.
941
- operations : str
942
- One or more valid operations in self.operations. For example, if operations = "<=",
943
- this checks `current_values <= reference_values` (note order of values).
944
- Or, if operations == ["<=", ">"], this checks
945
- ```
946
- (
947
- (current_values[0] <= reference_values[0])
948
- and (current_values[1] > reference_values[1])
949
- )
950
- ```
951
- tolerance : float or None (default)
952
- Tolerance to use in math.isclose if any of operations is "approx". Defaults
953
- to self.tolerance.
954
- append_comments : str or None (default)
955
- Additional comments that may be helpful for the user to understand why
956
- a check failed.
957
- severity : Literal["reason", "warning"]
958
- If a calculation fails, the severity of failure. Directs output to
959
- either reasons or warnings.
960
- """
961
-
962
- severity_to_list = {"reason" : reasons , "warning" : warnings }
963
-
964
- if not isinstance (operations , list ):
965
- operations = [operations ]
966
- current_values = [current_values ]
967
- reference_values = [reference_values ]
968
-
969
- unknown_operations = {operation for operation in operations if operation not in self .operations }
970
- if len (unknown_operations ) > 0 :
971
- raise ValueError ("Unknown operations:\n " + ", " .join ([f"{ uo } " for uo in unknown_operations ]))
972
-
973
- for iop in range (len (operations )):
974
- self ._check_parameter (
975
- error_list = severity_to_list [severity ],
976
- input_tag = input_tag ,
977
- current_value = current_values [iop ],
978
- reference_value = reference_values [iop ],
979
- operation = operations [iop ],
980
- tolerance = tolerance ,
981
- append_comments = append_comments ,
982
- )
0 commit comments