Skip to content

Commit da2a26c

Browse files
authored
Support more MACE foundation models (#109)
* Fixed errors in documentation * Support more MACE foundation models * Tests for MACE foundation models * Removed documentation on MACE-OMOL-0 * Don't register MACE-OMOL-0
1 parent bbf86e6 commit da2a26c

File tree

5 files changed

+109
-17
lines changed

5 files changed

+109
-17
lines changed

doc/userguide.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ are supported.
7878

7979
| Name | Model |
8080
| --- | --- |
81-
| `mace-off23-small`<br>`mace-off23-medium`<br>`mace-off23-large` | Pretrained [MACE-OFF23](https://arxiv.org/abs/2312.15211) models |
81+
| `mace-off23-small`<br>`mace-off23-medium`<br>`mace-off23-large`<br>`mace-off24-medium` | Pretrained [MACE-OFF](https://pubs.acs.org/doi/10.1021/jacs.4c07099) models |
82+
| `mace-mpa-0-medium` | Pretrained [MACE-MPA-0](https://github.com/ACEsuit/mace-foundations) model |
83+
| `mace-omat-0-small`<br>`mace-omat-0-medium` | Pretrained [MACE-OMAT-0](https://github.com/ACEsuit/mace-foundations) models |
8284
| `mace` | Custom MACE models specified with the `modelPath` argument |
8385

8486
When creating MACE models, the following keyword arguments to the `MLPotential` constructor are supported.
@@ -92,7 +94,7 @@ When using MACE models, the following extra keyword arguments to `createSystem()
9294
| Argument | Description |
9395
| --- | --- |
9496
| `precision` | The numerical precision of the model. Supported options are `'single'` and `'double'`. If `None`, the default precision of the model is used. |
95-
| `returnEnergyType` | Whether to return the interaction energy or the energy including the self-energy. The default is `'interaction_energy'`. Supported options are `'interaction_energy'` and `'energy'`. |
97+
| `returnEnergyType` | Whether to return the interaction energy or the energy including the self-energy. The default is `'interaction_energy'`. Supported options are `'interaction_energy'` and `'energy'`. |
9698

9799
### AIMNet2
98100

@@ -106,9 +108,9 @@ are supported.
106108

107109
When using AIMNet2 models, the following extra keyword arguments to `createSystem()` and `createMixedSystem()` are supported.
108110

109-
| Argument | Description |
111+
| Argument | Description |
110112
| --- | --- |
111-
| `charge` | The total charge of the system. If omitted, it is assumed to be 0. |
113+
| `charge` | The total charge of the system. If omitted, it is assumed to be 0. |
112114
| `multiplicity` | The spin multiplicity of the system. If omitted, it is assumed to be 1. |
113115

114116
### NequIP

openmmml/models/macepotential.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ class MACEPotentialImpl(MLPotentialImpl):
4747
4848
The MACE potential is constructed using MACE to build a PyTorch model,
4949
and then integrated into the OpenMM System using a TorchForce.
50-
This implementation supports both MACE-OFF23 and locally trained MACE models.
50+
This implementation supports both foundation models and locally trained MACE models.
5151
52-
To use one of the pre-trained MACE-OFF23 models, specify the model name. For example:
52+
To use one of the pre-trained MACE foundation models, specify the model name. For example:
5353
5454
>>> potential = MLPotential('mace-off23-small')
5555
56-
Other available MACE-OFF23 models include 'mace-off23-medium' and 'mace-off23-large'.
56+
Other available models include 'mace-off23-medium', 'mace-off23-large', and 'mace-off24-medium'.
5757
5858
To use a locally trained MACE model, provide the path to the model file. For example:
5959
@@ -93,7 +93,8 @@ def __init__(self, name: str, modelPath) -> None:
9393
----------
9494
name : str
9595
The name of the MACE model.
96-
Options include 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large', and 'mace'.
96+
Options include 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large',
97+
'mace-off24-medium', and 'mace'.
9798
modelPath : str, optional
9899
The path to the locally trained MACE model if ``name`` is 'mace'.
99100
"""
@@ -135,7 +136,7 @@ def addForces(
135136

136137
try:
137138
from mace.tools import utils, to_one_hot, atomic_numbers_to_indices
138-
from mace.calculators.foundations_models import mace_off
139+
from mace.calculators.foundations_models import mace_off, mace_mp, mace_omol
139140
except ImportError as e:
140141
raise ImportError(
141142
f"Failed to import mace with error: {e}. "
@@ -161,13 +162,24 @@ def addForces(
161162
"energy",
162163
], f"Unsupported returnEnergyType: '{returnEnergyType}'. Supported options are 'interaction_energy' or 'energy'."
163164

165+
models = {
166+
'mace-off23-small': (mace_off, 'small', True),
167+
'mace-off23-medium': (mace_off, 'medium', True),
168+
'mace-off23-large': (mace_off, 'large', True),
169+
'mace-off24-medium': (mace_off, 'https://github.com/ACEsuit/mace-off/blob/main/mace_off24/MACE-OFF24_medium.model?raw=true', True),
170+
'mace-mpa-0-medium': (mace_mp, 'medium-mpa-0', False),
171+
'mace-omat-0-small': (mace_mp, 'small-omat-0', True),
172+
'mace-omat-0-medium': (mace_mp, 'medium-omat-0', True),
173+
'mace-omol-0-extra-large': (mace_omol, 'extra_large', True)
174+
}
175+
164176
# Load the model to the CPU (OpenMM-Torch takes care of loading to the right devices)
165-
if self.name.startswith("mace-off23"):
166-
size = self.name.split("-")[-1]
167-
assert (
168-
size in ["small", "medium", "large"]
169-
), f"Unsupported MACE model: '{self.name}'. Available MACE-OFF23 models are 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'"
170-
model = mace_off(model=size, device="cpu", return_raw_model=True)
177+
if self.name in models:
178+
fn, name, warn = models[self.name]
179+
model = fn(model=name, device="cpu", return_raw_model=True)
180+
if warn:
181+
import logging
182+
logging.warning(f'The model {self.name} is distributed under the restrictive ASL license. Commercial use is not permitted.')
171183
elif self.name == "mace":
172184
if self.modelPath is not None:
173185
model = torch.load(self.modelPath, map_location="cpu")
@@ -231,6 +243,10 @@ class MACEForce(torch.nn.Module):
231243
Conversion factor for the length, viz. nm to Angstrom.
232244
indices : torch.Tensor
233245
The indices of the atoms to calculate the energy for.
246+
charge : float
247+
Total charge of the system
248+
multiplicity : float
249+
Spin multiplicity of the system
234250
returnEnergyType : str
235251
Whether to return the interaction energy or the energy including the self-energy.
236252
inputDict : dict
@@ -242,6 +258,8 @@ def __init__(
242258
model: torch.jit._script.RecursiveScriptModule,
243259
nodeAttrs: torch.Tensor,
244260
atoms: Optional[Iterable[int]],
261+
charge: float,
262+
multiplicity: float,
245263
periodic: bool,
246264
dtype: torch.dtype,
247265
returnEnergyType: str,
@@ -282,6 +300,8 @@ def __init__(
282300
self.register_buffer("node_attrs", nodeAttrs.to(self.dtype))
283301
self.register_buffer("batch", torch.zeros(nodeAttrs.shape[0], dtype=torch.long, requires_grad=False))
284302
self.register_buffer("pbc", torch.tensor([periodic, periodic, periodic], dtype=torch.bool, requires_grad=False))
303+
self.register_buffer("charge", torch.tensor([charge], dtype=dtype, requires_grad=False))
304+
self.register_buffer("multiplicity", torch.tensor([multiplicity], dtype=dtype, requires_grad=False))
285305

286306
def _getNeighborPairs(
287307
self, positions: torch.Tensor, cell: Optional[torch.Tensor]
@@ -371,7 +391,9 @@ def forward(
371391
"edge_index": edgeIndex,
372392
"shifts": shifts,
373393
"cell": cell if cell is not None else torch.zeros(3, 3, dtype=self.dtype),
374-
}
394+
"total_charge": self.charge,
395+
"total_spin": self.multiplicity
396+
}
375397

376398
# Predict the energy.
377399
energy = self.model(inputDict, compute_force=False)[
@@ -392,6 +414,8 @@ def forward(
392414
model,
393415
nodeAttrs,
394416
atoms,
417+
float(args.get('charge', 0)),
418+
float(args.get('multiplicity', 1)),
395419
isPeriodic,
396420
dtype,
397421
returnEnergyType,

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
'mace-off23-small = openmmml.models.macepotential:MACEPotentialImplFactory',
4949
'mace-off23-medium = openmmml.models.macepotential:MACEPotentialImplFactory',
5050
'mace-off23-large = openmmml.models.macepotential:MACEPotentialImplFactory',
51+
'mace-off24-medium = openmmml.models.macepotential:MACEPotentialImplFactory',
52+
'mace-mpa-0-medium = openmmml.models.macepotential:MACEPotentialImplFactory',
53+
'mace-omat-0-small = openmmml.models.macepotential:MACEPotentialImplFactory',
54+
'mace-omat-0-medium = openmmml.models.macepotential:MACEPotentialImplFactory',
5155
'nequip = openmmml.models.nequippotential:NequIPPotentialImplFactory',
5256
'deepmd = openmmml.models.deepmdpotential:DeepmdPotentialImplFactory',
5357
'torchmdnet = openmmml.models.torchmdnetpotential:TorchMDNetPotentialImplFactory',
@@ -56,4 +60,3 @@
5660
]
5761
}
5862
)
59-

test/TestMACEPotential.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
3+
import numpy as np
4+
import openmm as mm
5+
import openmm.app as app
6+
import openmm.unit as unit
7+
import pytest
8+
9+
from openmmml import MLPotential
10+
11+
mace = pytest.importorskip("mace", reason="mace is not installed")
12+
platform_ints = range(mm.Platform.getNumPlatforms())
13+
# Get the path to the test data
14+
test_data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
15+
16+
@pytest.mark.parametrize("platform_int", list(platform_ints))
17+
@pytest.mark.parametrize("model", ['mace-off23-small', 'mace-off23-medium', 'mace-off23-large', 'mace-off24-medium',
18+
'mace-mpa-0-medium', 'mace-omat-0-small', 'mace-omat-0-medium'])#, 'mace-omol-0-extra-large'])
19+
class TestMACE:
20+
def testCreatePureMLSystem(self, platform_int, model):
21+
pdb = app.PDBFile(os.path.join(test_data_dir, "toluene", "toluene.pdb"))
22+
potential = MLPotential(model)
23+
system = potential.createSystem(pdb.topology, returnEnergyType='energy')
24+
platform = mm.Platform.getPlatform(platform_int)
25+
context = mm.Context(system, mm.VerletIntegrator(0.001), platform)
26+
context.setPositions(pdb.getPositions(asNumpy=True))
27+
energyML = context.getState(energy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
28+
# Reference energies are calculated with MACECalculator
29+
refEnergy = {'mace-off23-small': -713468.6327560507,
30+
'mace-off23-medium': -713468.0563706581,
31+
'mace-off23-large': -713467.7476380612,
32+
'mace-off24-medium': -713467.9394350434,
33+
'mace-mpa-0-medium': -8839.299589829867,
34+
'mace-omat-0-small': -8726.63865431241,
35+
'mace-omat-0-medium': -8679.026847088873,
36+
'mace-omol-0-extra-large': -712903.4934289698}
37+
assert np.isclose(refEnergy[model], energyML, rtol=1e-6)

test/data/toluene/mace_energies.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This script computes reference energies for the MACE foundation models.
2+
3+
import ase.io
4+
from mace.calculators.foundations_models import mace_off, mace_mp, mace_omol
5+
from openmm.unit import kilojoules_per_mole, ev, item
6+
7+
atoms = ase.io.read('toluene.pdb')
8+
results = {}
9+
atoms.calc = mace_off('small')
10+
results['mace-off23-small'] = atoms.get_potential_energy()
11+
atoms.calc = mace_off('medium')
12+
results['mace-off23-medium'] = atoms.get_potential_energy()
13+
atoms.calc = mace_off('large')
14+
results['mace-off23-large'] = atoms.get_potential_energy()
15+
atoms.calc = mace_off('https://github.com/ACEsuit/mace-off/blob/main/mace_off24/MACE-OFF24_medium.model?raw=true')
16+
results['mace-off24-medium'] = atoms.get_potential_energy()
17+
atoms.calc = mace_mp('medium-mpa-0')
18+
results['mace-mpa-0-medium'] = atoms.get_potential_energy()
19+
atoms.calc = mace_mp('small-omat-0')
20+
results['mace-omat-0-small'] = atoms.get_potential_energy()
21+
atoms.calc = mace_mp('medium-omat-0')
22+
results['mace-omat-0-medium'] = atoms.get_potential_energy()
23+
atoms.calc = mace_omol('extra_large')
24+
results['mace-omol-0-extra-large'] = atoms.get_potential_energy()
25+
for key in results:
26+
print(f'{key}: {(results[key]*ev/item).value_in_unit(kilojoules_per_mole)}')

0 commit comments

Comments
 (0)