Skip to content

Commit 9c4447c

Browse files
orionarcherJaGeo
andauthored
Allow custom mace model by specifying "model" in calculator kwargs" (#1017)
* allow custom mace model by specifying "model" in calculator kwargs" * fix error in trying to turn None into path * Add support for ORB model * Specify more dependencies * remove orb implementation * add line * add line * remove device * set device * fix set device * fix test * fix linting * restore test * remove os and rely on pathlib only --------- Co-authored-by: J. George <[email protected]> Co-authored-by: JaGeo <[email protected]>
1 parent 42bc7b8 commit 9c4447c

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/atomate2/forcefields/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
from contextlib import contextmanager
7+
from pathlib import Path
78
from typing import TYPE_CHECKING
89

910
from monty.json import MontyDecoder
@@ -59,9 +60,21 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
5960
calculator = PESCalculator(potential, **kwargs)
6061

6162
elif calculator_name == MLFF.MACE:
62-
from mace.calculators import mace_mp
63-
64-
calculator = mace_mp(**kwargs)
63+
from mace.calculators import MACECalculator, mace_mp
64+
65+
model = kwargs.get("model")
66+
if isinstance(model, str | Path) and Path(model).exists():
67+
model_path = model
68+
device = kwargs.get("device") or "cpu"
69+
if "device" in kwargs:
70+
del kwargs["device"]
71+
calculator = MACECalculator(
72+
model_paths=model_path,
73+
device=device,
74+
**kwargs,
75+
)
76+
else:
77+
calculator = mace_mp(**kwargs)
6578

6679
elif calculator_name == MLFF.GAP:
6780
from quippy.potential import Potential

tests/forcefields/test_jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def test_mace_relax_maker(
277277
# NOTE the test model is not trained on Si, so the energy is not accurate
278278
job = ForceFieldRelaxMaker(
279279
force_field_name="MACE",
280-
calculator_kwargs={"model": model},
280+
calculator_kwargs={"model": model, "default_dtype": "float32"},
281281
steps=25,
282282
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
283283
relax_cell=relax_cell,
@@ -308,7 +308,7 @@ def test_mace_relax_maker(
308308

309309
if fix_symmetry: # if symmetry is fixed, the symmetry should be the same or higher
310310
assert is_subgroup(symmetry_ops_init, symmetry_ops_final)
311-
else: # if symmetry is not fixed, it can both increase or decrease
311+
else: # if symmetry is not fixed, it can both increase or decrease or stay the same
312312
assert not is_subgroup(symmetry_ops_init, symmetry_ops_final)
313313

314314
if relax_cell:

0 commit comments

Comments
 (0)