Skip to content

Commit f85e91e

Browse files
authored
forcefields: Standardize MLFF handling and maker naming (#1360)
* Standardize MLFF handling and maker naming * Introduce mlff property in tutorial
1 parent eeeec5d commit f85e91e

File tree

7 files changed

+105
-63
lines changed

7 files changed

+105
-63
lines changed

src/atomate2/forcefields/flows/approx_neb.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from typing_extensions import Self
1010

1111
from atomate2.common.flows.approx_neb import ApproxNebFromEndpointsMaker
12-
from atomate2.forcefields import MLFF, _get_formatted_ff_name
1312
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
13+
from atomate2.forcefields.utils import MLFF
1414

1515

1616
@dataclass
@@ -84,15 +84,11 @@ def from_force_field_name(
8484
-------
8585
MLFFApproxNebFromEndpointsMaker
8686
"""
87-
force_field_name = _get_formatted_ff_name(force_field_name)
88-
kwargs.update(
89-
image_relax_maker=ForceFieldRelaxMaker(
90-
force_field_name=force_field_name, relax_cell=False
91-
),
87+
image_relax_maker = ForceFieldRelaxMaker(
88+
force_field_name=force_field_name, relax_cell=False
9289
)
90+
kwargs.update(image_relax_maker=image_relax_maker)
9391
return cls(
94-
name=(
95-
f"{force_field_name.split('MLFF.')[-1]} ApproxNEB from endpoints Maker"
96-
),
92+
name=(f"{image_relax_maker.mlff.name} ApproxNEB from endpoints Maker"),
9793
**kwargs,
9894
)

src/atomate2/forcefields/flows/elastic.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
from atomate2 import SETTINGS
99
from atomate2.common.flows.elastic import BaseElasticMaker
10-
from atomate2.forcefields import MLFF, _get_formatted_ff_name
1110
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
1211

1312
if TYPE_CHECKING:
1413
from typing import Any
1514

1615
from typing_extensions import Self
1716

17+
from atomate2.forcefields import MLFF
18+
1819
# default options for the forcefield makers in ElasticMaker
1920
_DEFAULT_RELAX_KWARGS: dict[str, Any] = {
2021
"force_field_name": "CHGNet",
@@ -125,19 +126,20 @@ def from_force_field_name(
125126
default_kwargs: dict[str, Any] = {
126127
**_DEFAULT_RELAX_KWARGS,
127128
**(mlff_kwargs or {}),
128-
"force_field_name": _get_formatted_ff_name(force_field_name),
129+
"force_field_name": force_field_name,
129130
}
131+
bulk_relax_maker = ForceFieldRelaxMaker(
132+
relax_cell=True,
133+
**default_kwargs,
134+
)
130135
kwargs.update(
131-
bulk_relax_maker=ForceFieldRelaxMaker(
132-
relax_cell=True,
133-
**default_kwargs,
134-
),
136+
bulk_relax_maker=bulk_relax_maker,
135137
elastic_relax_maker=ForceFieldRelaxMaker(
136138
relax_cell=False,
137139
**default_kwargs,
138140
),
139141
)
140142
return cls(
141-
name=f"{str(force_field_name).split('MLFF.')[-1]} elastic",
143+
name=f"{bulk_relax_maker.mlff.name} elastic",
142144
**kwargs,
143145
)

src/atomate2/forcefields/flows/eos.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING
77

88
from atomate2.common.flows.eos import CommonEosMaker
9-
from atomate2.forcefields import _get_formatted_ff_name
109
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
1110

1211
if TYPE_CHECKING:
@@ -81,22 +80,19 @@ def from_force_field_name(
8180
-------
8281
ForceFieldEosMaker
8382
"""
84-
force_field_name = _get_formatted_ff_name(force_field_name)
83+
eos_relax_maker = ForceFieldRelaxMaker(
84+
force_field_name=force_field_name, relax_cell=False
85+
)
8586
kwargs.update(
8687
initial_relax_maker=(
8788
ForceFieldRelaxMaker(force_field_name=force_field_name)
8889
if relax_initial_structure
8990
else None
90-
)
91-
)
92-
93-
kwargs.update(
94-
eos_relax_maker=ForceFieldRelaxMaker(
95-
force_field_name=force_field_name, relax_cell=False
9691
),
92+
eos_relax_maker=eos_relax_maker,
9793
static_maker=None,
9894
)
9995
return cls(
100-
name=f"{force_field_name.split('MLFF.')[-1]} EOS Maker",
96+
name=f"{eos_relax_maker.mlff.name} EOS Maker",
10197
**kwargs,
10298
)

src/atomate2/forcefields/flows/phonons.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from atomate2 import SETTINGS
99
from atomate2.common.flows.phonons import BasePhononMaker
10-
from atomate2.forcefields import _get_formatted_ff_name
1110
from atomate2.forcefields.jobs import ForceFieldRelaxMaker, ForceFieldStaticMaker
11+
from atomate2.forcefields.utils import MLFF
1212

1313
if TYPE_CHECKING:
1414
from typing_extensions import Self
@@ -153,6 +153,11 @@ def prev_calc_dir_argname(self) -> None:
153153
"""
154154
return
155155

156+
@property
157+
def mlff(self) -> MLFF:
158+
"""The MLFF enum corresponding to the force field name."""
159+
return self.phonon_displacement_maker.mlff
160+
156161
@classmethod
157162
def from_force_field_name(
158163
cls,
@@ -177,24 +182,22 @@ def from_force_field_name(
177182
-------
178183
PhononMaker
179184
"""
180-
force_field_name = _get_formatted_ff_name(force_field_name)
181-
185+
static_energy_maker = ForceFieldStaticMaker(force_field_name=force_field_name)
182186
kwargs.update(
183187
bulk_relax_maker=(
184188
ForceFieldRelaxMaker(
185189
force_field_name=force_field_name, relax_kwargs={"fmax": 1e-5}
186190
)
187191
if relax_initial_structure
188192
else None
189-
)
190-
)
191-
kwargs.update(
192-
static_energy_maker=ForceFieldStaticMaker(
193-
force_field_name=force_field_name
194193
),
194+
static_energy_maker=static_energy_maker,
195195
phonon_displacement_maker=ForceFieldStaticMaker(
196196
force_field_name=force_field_name
197197
),
198198
born_maker=None,
199199
)
200-
return cls(name=f"{force_field_name.split('MLFF.')[-1]} Phonon Maker", **kwargs)
200+
return cls(
201+
name=(f"{static_energy_maker.mlff.name} Phonon Maker"),
202+
**kwargs,
203+
)

src/atomate2/forcefields/flows/qha.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING, Literal
77

88
from atomate2.common.flows.qha import CommonQhaMaker
9-
from atomate2.forcefields import _get_formatted_ff_name
109
from atomate2.forcefields.flows.phonons import PhononMaker
1110
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
1211

@@ -117,16 +116,12 @@ def from_force_field_name(
117116
-------
118117
ForceFieldQhaMaker
119118
"""
120-
force_field_name = _get_formatted_ff_name(force_field_name)
121119
kwargs.update(
122120
initial_relax_maker=(
123121
ForceFieldRelaxMaker(force_field_name=force_field_name)
124122
if relax_initial_structure
125123
else None
126-
)
127-
)
128-
129-
kwargs.update(
124+
),
130125
eos_relax_maker=(
131126
ForceFieldRelaxMaker(
132127
force_field_name=force_field_name,
@@ -135,13 +130,14 @@ def from_force_field_name(
135130
)
136131
if run_eos_flow
137132
else None
138-
)
133+
),
134+
)
135+
phonon_maker = PhononMaker.from_force_field_name(
136+
force_field_name=force_field_name, relax_initial_structure=False
139137
)
140138
return cls(
141-
phonon_maker=PhononMaker.from_force_field_name(
142-
force_field_name=force_field_name, relax_initial_structure=False
143-
),
144-
name=f"{force_field_name.split('MLFF.')[-1]} QHA Maker",
139+
phonon_maker=phonon_maker,
140+
name=f"{phonon_maker.mlff.name} QHA Maker",
145141
**kwargs,
146142
)
147143

src/atomate2/forcefields/utils.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ase.units import Bohr
1313
from ase.units import GPa as _GPa_to_eV_per_A3
1414
from monty.json import MontyDecoder
15+
from typing_extensions import deprecated
1516

1617
if TYPE_CHECKING:
1718
from collections.abc import Generator
@@ -75,27 +76,34 @@ def _missing_(cls, value: Any) -> Any:
7576
}
7677

7778

78-
def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
79-
"""
80-
Get the standardized force field name.
79+
def _get_standardized_mlff(force_field_name: str | MLFF) -> MLFF:
80+
"""Get the standardized force field name.
8181
8282
Parameters
8383
----------
8484
force_field_name : str or .MLFF
8585
The name of the force field
86+
For str, accept both with and without the `MLFF.` prefix.
8687
8788
Returns
8889
-------
89-
str : the name of the forcefield from MLFF
90+
MLFF: the name of the forcefield
9091
"""
9192
if isinstance(force_field_name, str):
9293
# ensure `force_field_name` uses enum format
94+
if force_field_name.startswith("MLFF."):
95+
force_field_name = force_field_name.split("MLFF.")[-1]
96+
9397
if force_field_name in MLFF.__members__:
9498
force_field_name = MLFF[force_field_name]
9599
elif force_field_name in [v.value for v in MLFF]:
96100
force_field_name = MLFF(force_field_name)
97-
force_field_name = str(force_field_name)
98-
if force_field_name in {"MLFF.MACE", "MACE"}:
101+
else:
102+
raise ValueError(
103+
f"force_field_name={force_field_name} is not a valid MLFF name."
104+
)
105+
106+
if force_field_name == MLFF.MACE:
99107
warnings.warn(
100108
"Because the default MP-trained MACE model is constantly evolving, "
101109
"we no longer recommend using `MACE` or `MLFF.MACE` to specify "
@@ -108,6 +116,24 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
108116
return force_field_name
109117

110118

119+
@deprecated("Use _get_standardized_mlff instead.")
120+
def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
121+
"""
122+
Get the standardized force field name.
123+
124+
Parameters
125+
----------
126+
force_field_name : str or .MLFF
127+
The name of the force field
128+
129+
Returns
130+
-------
131+
str : the name of the forcefield from MLFF
132+
"""
133+
force_field_name = _get_standardized_mlff(force_field_name)
134+
return str(force_field_name)
135+
136+
111137
@dataclass
112138
class ForceFieldMixin:
113139
"""Mix-in class for force-fields.
@@ -134,18 +160,17 @@ def __post_init__(self) -> None:
134160
if hasattr(super(), "__post_init__"):
135161
super().__post_init__() # type: ignore[misc]
136162

137-
self.force_field_name = _get_formatted_ff_name(self.force_field_name)
163+
mlff = _get_standardized_mlff(self.force_field_name)
164+
self.force_field_name: str = str(mlff) # Narrow-down type for mypy
138165

139166
# Pad calculator_kwargs with default values, but permit user to override them
140167
self.calculator_kwargs = {
141-
**_DEFAULT_CALCULATOR_KWARGS.get(
142-
MLFF(self.force_field_name.split("MLFF.")[-1]), {}
143-
),
168+
**_DEFAULT_CALCULATOR_KWARGS.get(mlff, {}),
144169
**self.calculator_kwargs,
145170
}
146171

147172
if not self.task_document_kwargs.get("force_field_name"):
148-
self.task_document_kwargs["force_field_name"] = str(self.force_field_name)
173+
self.task_document_kwargs["force_field_name"] = self.force_field_name
149174

150175
def _run_ase_safe(self, *args, **kwargs) -> AseResult:
151176
if not hasattr(self, "run_ase"):
@@ -159,10 +184,15 @@ def _run_ase_safe(self, *args, **kwargs) -> AseResult:
159184
def calculator(self) -> Calculator:
160185
"""ASE calculator, can be overwritten by user."""
161186
return ase_calculator(
162-
str(self.force_field_name), # make mypy happy
187+
self.force_field_name,
163188
**self.calculator_kwargs,
164189
)
165190

191+
@property
192+
def mlff(self) -> MLFF:
193+
"""The MLFF enum corresponding to the force field name."""
194+
return MLFF(str(self.force_field_name).split("MLFF.")[-1])
195+
166196

167197
def ase_calculator(
168198
calculator_meta: str | MLFF | dict, **kwargs: Any

tutorials/force_fields/phonon_workflow.ipynb

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,16 +134,15 @@
134134
"metadata": {},
135135
"outputs": [],
136136
"source": [
137-
"PhononMaker.from_force_field_name(force_field_name=\"MACE_MP_0B3\")\n",
138-
"run_locally(flow, create_folders=True, raise_immediately=True, root_dir=tmp_dir)"
137+
"maker = PhononMaker.from_force_field_name(force_field_name=\"MACE_MP_0B3\")"
139138
]
140139
},
141140
{
142141
"cell_type": "markdown",
143142
"id": "11",
144143
"metadata": {},
145144
"source": [
146-
"Now, we clean up the temporary directory that we made. In reality, you might want to keep this data."
145+
"We can confirm that the specified force field is being used via `PhononMaker.mlff` property:"
147146
]
148147
},
149148
{
@@ -152,6 +151,26 @@
152151
"id": "12",
153152
"metadata": {},
154153
"outputs": [],
154+
"source": [
155+
"from atomate2.forcefields.utils import MLFF\n",
156+
"\n",
157+
"assert maker.mlff == MLFF.MACE_MP_0B3 # noqa: S101"
158+
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"id": "13",
163+
"metadata": {},
164+
"source": [
165+
"Now, we clean up the temporary directory that we made. In reality, you might want to keep this data."
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"id": "14",
172+
"metadata": {},
173+
"outputs": [],
155174
"source": [
156175
"import shutil\n",
157176
"\n",
@@ -163,14 +182,14 @@
163182
"language_info": {
164183
"codemirror_mode": {
165184
"name": "ipython",
166-
"version": 2
185+
"version": 3
167186
},
168187
"file_extension": ".py",
169188
"mimetype": "text/x-python",
170189
"name": "python",
171190
"nbconvert_exporter": "python",
172-
"pygments_lexer": "ipython2",
173-
"version": "2.7.6"
191+
"pygments_lexer": "ipython3",
192+
"version": "3.11.14"
174193
}
175194
},
176195
"nbformat": 4,

0 commit comments

Comments
 (0)