Skip to content

Commit 655a34e

Browse files
authored
Merge pull request #376 from anyangml2nd/feat/support-new-molecule-tasks
Feat: support two molecule tasks
2 parents aca46ad + ab928d9 commit 655a34e

File tree

6 files changed

+172
-7
lines changed

6 files changed

+172
-7
lines changed

lambench/metrics/downstream_tasks_metrics.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,11 @@ vacancy:
2424
domain: Inorganic Materials
2525
metrics: [MAE]
2626
dummy: {"MAE": 4.381}
27+
binding_energy:
28+
domain: Molecules
29+
metrics: [MAE]
30+
dummy: {"MAE": 8.098}
31+
rxn_barrier:
32+
domain: Molecules
33+
metrics: [MAE]
34+
dummy: {"MAE": 20.975}

lambench/metrics/post_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel):
118118
"wiggle150",
119119
"elastic",
120120
"vacancy",
121+
"binding_energy",
122+
"rxn_barrier",
121123
]:
122124
applicability_results[record.task_name] = record.metrics
123125
return applicability_results

lambench/models/ase_models.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(self, *args, **kwargs):
8282
self._calc = None
8383

8484
@property
85-
def calc(self) -> Calculator:
85+
def calc(self, head=None) -> Calculator:
8686
"""ASE Calculator with the model loaded."""
8787
calculator_dispatch = {
8888
"MACE": self._init_mace_calculator,
@@ -101,7 +101,6 @@ def calc(self) -> Calculator:
101101
f"Model {self.model_name} is not supported by ASEModel, using EMT as default calculator."
102102
)
103103
self._calc = EMT()
104-
105104
else:
106105
self._calc = calculator_dispatch[self.model_family]()
107106
return self._calc
@@ -114,10 +113,12 @@ def calc(self, value: Calculator):
114113
def _init_mace_calculator(self) -> Calculator:
115114
from mace.calculators import mace_mp
116115

116+
if self.model_domain == "molecules":
117+
head = "omol"
118+
else:
119+
head = "oc20_usemppbe"
117120
return mace_mp(
118-
model=self.model_name.split("_")[-1],
119-
device="cuda",
120-
default_dtype="float64",
121+
model=self.model_path, device="cuda", default_dtype="float64", head=head
121122
)
122123

123124
def _init_orb_calculator(self) -> Calculator:
@@ -134,7 +135,7 @@ def _init_sevennet_calculator(self) -> Calculator:
134135

135136
model_config = {"model": self.model_name, "device": "cuda"}
136137
if self.model_name == "7net-mf-ompa":
137-
model_config["modal"] = "mpa"
138+
model_config["modal"] = "omat24"
138139
return SevenNetCalculator(**model_config)
139140

140141
def _init_equiformer_calculator(self) -> Calculator:
@@ -171,7 +172,7 @@ def _init_dp_calculator(self) -> Calculator:
171172
else:
172173
return DP(
173174
model=self.model_path,
174-
head="MP_traj_v024_alldata_mixu",
175+
head="Omat24",
175176
)
176177

177178
def _init_grace_calculator(self) -> Calculator:
@@ -290,6 +291,16 @@ def evaluate(
290291
elif task.task_name == "vacancy":
291292
from lambench.tasks.calculator.vacancy.vacancy import run_inference
292293

294+
assert task.test_data is not None
295+
return {"metrics": run_inference(self, task.test_data)}
296+
elif task.task_name == "rxn_barrier":
297+
from lambench.tasks.calculator.rxn_barrier.barrier import run_inference
298+
299+
assert task.test_data is not None
300+
return {"metrics": run_inference(self, task.test_data)}
301+
elif task.task_name == "binding_energy":
302+
from lambench.tasks.calculator.binding.binding import run_inference
303+
293304
assert task.test_data is not None
294305
return {"metrics": run_inference(self, task.test_data)}
295306
else:
lambench/tasks/calculator/binding/binding.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
The test data is retrieved from:
3+
J. Chem. Inf. Model. 2020, 60, 3, 1453–1460
4+
5+
https://pubs.acs.org/doi/10.1021/acs.jcim.9b01171
6+
7+
Only the PLF547 dataset is used.
8+
9+
"""
10+
11+
from ase.io import read
12+
import numpy as np
13+
from tqdm import tqdm
14+
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
15+
from pathlib import Path
16+
from lambench.models.ase_models import ASEModel
17+
import logging
18+
19+
20+
def run_inference(
    model: ASEModel,
    test_data: Path,
) -> dict[str, float]:
    """Evaluate binding-energy predictions of ``model`` on the data in ``test_data``.

    For every sample the binding energy is computed as
    ``E(combined) - E(active_site) - E(drug)`` using ``model.calc`` and
    multiplied by ``EV_TO_KCAL`` before being compared with the reference label.

    Args:
        model: Model wrapper whose ``calc`` attribute is attached to each
            structure as its ASE calculator.
        test_data: Directory containing ``active_site.traj``, ``drug.traj``,
            ``combined.traj`` and ``labels.npy``. Each structure must carry
            ``charge`` and ``spin`` entries in ``atoms.info``.

    Returns:
        A dict with ``MAE`` and ``RMSE`` (kcal/mol) over the samples that were
        evaluated successfully, and ``success_rate``, the fraction of labels
        for which a prediction was obtained. If no sample succeeds, ``MAE`` and
        ``RMSE`` are NaN and ``success_rate`` is 0.0.
    """
    active_site_atoms = read(test_data / "active_site.traj", ":")
    drug_atoms = read(test_data / "drug.traj", ":")
    combined_atoms = read(test_data / "combined.traj", ":")
    labels = np.load(test_data / "labels.npy")

    EV_TO_KCAL = 23.06092234465

    num_labels = len(labels)
    # zip() below stops at the shortest input; surface a mismatch instead of
    # silently dropping samples (which would also distort success_rate).
    if not (
        len(active_site_atoms) == len(drug_atoms) == len(combined_atoms) == num_labels
    ):
        logging.warning(
            "Binding-energy inputs differ in length "
            "(active_site=%d, drug=%d, combined=%d, labels=%d); "
            "extra entries are ignored.",
            len(active_site_atoms),
            len(drug_atoms),
            len(combined_atoms),
            num_labels,
        )

    calc = model.calc
    preds = []
    success_labels = []

    samples = zip(active_site_atoms, drug_atoms, combined_atoms, labels)
    # zip() has no len(), so tqdm needs an explicit total to show progress.
    for index, (site, drug, combo, label) in enumerate(
        tqdm(samples, total=num_labels)
    ):
        try:
            for atoms in (site, drug, combo):
                atoms.calc = calc
                # Expose charge and spin under the "fparam" key
                # (presumably read by the calculator -- TODO confirm).
                atoms.info.update(
                    {"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])}
                )

            site_energy = site.get_potential_energy()
            drug_energy = drug.get_potential_energy()
            combo_energy = combo.get_potential_energy()

            binding_energy = combo_energy - site_energy - drug_energy
            preds.append(binding_energy * EV_TO_KCAL)
            success_labels.append(label)
        except Exception as e:
            # Best effort: a failing sample only lowers success_rate.
            logging.warning(
                "Failed to calculate binding energy for sample %d: %s", index, e
            )
            continue

    if not preds:
        # sklearn raises on empty inputs; report the total failure instead.
        logging.warning("No binding energy could be calculated for any sample.")
        return {
            "MAE": float("nan"),
            "RMSE": float("nan"),
            "success_rate": 0.0,
        }

    return {
        "MAE": mean_absolute_error(success_labels, preds),  # kcal/mol
        "RMSE": root_mean_squared_error(success_labels, preds),  # kcal/mol
        "success_rate": len(success_labels) / num_labels,
    }

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ elastic:
3131
vacancy:
3232
test_data: /bohr/lambench-vacancy-a2xo/v1
3333
calculator_params: null
34+
binding_energy:
35+
test_data: /bohr/lambench-binding-dlc6/v1/PLF547
36+
calculator_params: null
37+
rxn_barrier:
38+
test_data: /bohr/lambench-BH876-uplk/v1/BH876
39+
calculator_params: null
lambench/tasks/calculator/rxn_barrier/barrier.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
The test data is retrieved from:
3+
4+
@misc{liang2025gold,
5+
title={Gold-Standard Chemical Database 137 (GSCDB137): A diverse set of accurate energy differences for assessing and developing density functionals},
6+
author={Jiashu Liang and Martin Head-Gordon},
7+
year={2025},
8+
eprint={2508.13468},
9+
archivePrefix={arXiv},
10+
primaryClass={physics.chem-ph},
11+
url={https://arxiv.org/abs/2508.13468},
12+
}
13+
14+
https://github.com/JiashuLiang/GSCDB
15+
16+
Only the BH876 dataset is used.
17+
18+
"""
19+
20+
from ase.io import read
21+
import pandas as pd
22+
import numpy as np
23+
from tqdm import tqdm
24+
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
25+
from pathlib import Path
26+
from lambench.models.ase_models import ASEModel
27+
import logging
28+
29+
30+
def run_inference(
    model: ASEModel,
    test_data: Path,
) -> dict[str, float]:
    """Evaluate reaction-energy predictions of ``model`` on the data in ``test_data``.

    Each row of ``stoichiometry.csv`` describes one reaction as a
    comma-separated list ``coeff,species_id,coeff,species_id,...``. The
    prediction is the coefficient-weighted sum of the species energies,
    multiplied by ``EV_TO_KCAL``; the label is the row's ``Reference`` value
    multiplied by ``HARTREE_TO_KCAL``.

    Args:
        model: Model wrapper whose ``calc`` attribute is attached to each
            structure as its ASE calculator.
        test_data: Directory containing ``lookup_table.csv`` (with an ``ID``
            column whose row order matches the frames of ``BH876.traj``),
            ``stoichiometry.csv`` (with ``Stoichiometry`` and ``Reference``
            columns) and ``BH876.traj``. Each structure must carry ``charge``
            and ``spin`` entries in ``atoms.info``.

    Returns:
        A dict with ``MAE`` and ``RMSE`` (kcal/mol) over the reactions that
        were evaluated successfully, and ``success_rate``, the fraction of
        reactions for which a prediction was obtained. If no reaction succeeds,
        ``MAE`` and ``RMSE`` are NaN and ``success_rate`` is 0.0.
    """
    lookup_table = pd.read_csv(test_data / "lookup_table.csv")
    stoichiometry = pd.read_csv(test_data / "stoichiometry.csv")
    traj = read(test_data / "BH876.traj", ":")

    EV_TO_KCAL = 23.06092234465
    HARTREE_TO_KCAL = 627.50947406

    # Map each species ID to the row position of its first occurrence once,
    # instead of scanning the whole table for every species of every reaction.
    structure_index_by_id = {}
    for position, species_id in enumerate(lookup_table["ID"]):
        structure_index_by_id.setdefault(species_id, position)

    # Structures are shared between reactions; evaluate each one only once.
    energy_by_structure = {}

    preds = []
    labels = []
    num_reactions = len(stoichiometry)

    calc = model.calc

    # iterrows() has no len(), so tqdm needs an explicit total to show progress.
    for _, row in tqdm(stoichiometry.iterrows(), total=num_reactions):
        try:
            tokens = row["Stoichiometry"].split(",")
            pred = 0.0
            # Tokens alternate between coefficient and species ID.
            for coefficient, species_id in zip(tokens[0::2], tokens[1::2]):
                weight = float(coefficient)
                if species_id not in structure_index_by_id:
                    raise LookupError(
                        f"species {species_id!r} not found in lookup_table.csv"
                    )
                structure_index = structure_index_by_id[species_id]
                if structure_index not in energy_by_structure:
                    atoms = traj[structure_index]
                    # Expose charge and spin under the "fparam" key
                    # (presumably read by the calculator -- TODO confirm).
                    atoms.info.update(
                        {
                            "fparam": np.array(
                                [atoms.info["charge"], atoms.info["spin"]]
                            )
                        }
                    )
                    atoms.calc = calc
                    energy_by_structure[structure_index] = (
                        atoms.get_potential_energy()
                    )
                pred += weight * energy_by_structure[structure_index]
            # Read the label before appending anything so that a failure here
            # cannot leave preds and labels with different lengths.
            label = row["Reference"] * HARTREE_TO_KCAL
        except Exception as e:
            # Best effort: a failing reaction only lowers success_rate.
            logging.warning(
                "Failed to calculate reaction energy for reaction: %s. Error: %s",
                row["Stoichiometry"],
                e,
            )
            continue
        preds.append(pred * EV_TO_KCAL)
        labels.append(label)

    if not preds:
        # sklearn raises on empty inputs; report the total failure instead.
        logging.warning("No reaction energy could be calculated for any reaction.")
        return {
            "MAE": float("nan"),
            "RMSE": float("nan"),
            "success_rate": 0.0,
        }

    return {
        "MAE": mean_absolute_error(labels, preds),  # kcal/mol
        "RMSE": root_mean_squared_error(labels, preds),  # kcal/mol
        "success_rate": len(preds) / num_reactions,
    }

0 commit comments

Comments
 (0)