1212from ase .units import Bohr
1313from ase .units import GPa as _GPa_to_eV_per_A3
1414from monty .json import MontyDecoder
15+ from typing_extensions import deprecated
1516
1617if TYPE_CHECKING :
1718 from collections .abc import Generator
@@ -75,27 +76,34 @@ def _missing_(cls, value: Any) -> Any:
7576}
7677
7778
78- def _get_formatted_ff_name (force_field_name : str | MLFF ) -> str :
79- """
80- Get the standardized force field name.
79+ def _get_standardized_mlff (force_field_name : str | MLFF ) -> MLFF :
80+ """Get the standardized force field name.
8181
8282 Parameters
8383 ----------
8484 force_field_name : str or .MLFF
8585 The name of the force field
86+ For str, accept both with and without the `MLFF.` prefix.
8687
8788 Returns
8889 -------
89- str : the name of the forcefield from MLFF
90+ MLFF : the name of the forcefield
9091 """
9192 if isinstance (force_field_name , str ):
9293 # ensure `force_field_name` uses enum format
94+ if force_field_name .startswith ("MLFF." ):
95+ force_field_name = force_field_name .split ("MLFF." )[- 1 ]
96+
9397 if force_field_name in MLFF .__members__ :
9498 force_field_name = MLFF [force_field_name ]
9599 elif force_field_name in [v .value for v in MLFF ]:
96100 force_field_name = MLFF (force_field_name )
97- force_field_name = str (force_field_name )
98- if force_field_name in {"MLFF.MACE" , "MACE" }:
101+ else :
102+ raise ValueError (
103+ f"force_field_name={ force_field_name } is not a valid MLFF name."
104+ )
105+
106+ if force_field_name == MLFF .MACE :
99107 warnings .warn (
100108 "Because the default MP-trained MACE model is constantly evolving, "
101109 "we no longer recommend using `MACE` or `MLFF.MACE` to specify "
@@ -108,6 +116,24 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
108116 return force_field_name
109117
110118
119+ @deprecated ("Use _get_standardized_mlff instead." )
120+ def _get_formatted_ff_name (force_field_name : str | MLFF ) -> str :
121+ """
122+ Get the standardized force field name.
123+
124+ Parameters
125+ ----------
126+ force_field_name : str or .MLFF
127+ The name of the force field
128+
129+ Returns
130+ -------
131+ str : the name of the forcefield from MLFF
132+ """
133+ force_field_name = _get_standardized_mlff (force_field_name )
134+ return str (force_field_name )
135+
136+
111137@dataclass
112138class ForceFieldMixin :
113139 """Mix-in class for force-fields.
@@ -134,18 +160,17 @@ def __post_init__(self) -> None:
134160 if hasattr (super (), "__post_init__" ):
135161 super ().__post_init__ () # type: ignore[misc]
136162
137- self .force_field_name = _get_formatted_ff_name (self .force_field_name )
163+ mlff = _get_standardized_mlff (self .force_field_name )
164+ self .force_field_name : str = str (mlff ) # Narrow-down type for mypy
138165
139166 # Pad calculator_kwargs with default values, but permit user to override them
140167 self .calculator_kwargs = {
141- ** _DEFAULT_CALCULATOR_KWARGS .get (
142- MLFF (self .force_field_name .split ("MLFF." )[- 1 ]), {}
143- ),
168+ ** _DEFAULT_CALCULATOR_KWARGS .get (mlff , {}),
144169 ** self .calculator_kwargs ,
145170 }
146171
147172 if not self .task_document_kwargs .get ("force_field_name" ):
148- self .task_document_kwargs ["force_field_name" ] = str ( self .force_field_name )
173+ self .task_document_kwargs ["force_field_name" ] = self .force_field_name
149174
150175 def _run_ase_safe (self , * args , ** kwargs ) -> AseResult :
151176 if not hasattr (self , "run_ase" ):
@@ -159,10 +184,15 @@ def _run_ase_safe(self, *args, **kwargs) -> AseResult:
159184 def calculator (self ) -> Calculator :
160185 """ASE calculator, can be overwritten by user."""
161186 return ase_calculator (
162- str ( self .force_field_name ), # make mypy happy
187+ self .force_field_name ,
163188 ** self .calculator_kwargs ,
164189 )
165190
191+ @property
192+ def mlff (self ) -> MLFF :
193+ """The MLFF enum corresponding to the force field name."""
194+ return MLFF (str (self .force_field_name ).split ("MLFF." )[- 1 ])
195+
166196
167197def ase_calculator (
168198 calculator_meta : str | MLFF | dict , ** kwargs : Any
0 commit comments