Skip to content

Commit ab928d9

Browse files
committed
chore: format
1 parent 834c8ea commit ab928d9

File tree

5 files changed

+35
-36
lines changed

5 files changed

+35
-36
lines changed

lambench/metrics/downstream_tasks_metrics.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ binding_energy:
3131
rxn_barrier:
3232
domain: Molecules
3333
metrics: [MAE]
34-
dummy: {"MAE": 20.975}
34+
dummy: {"MAE": 20.975}

lambench/models/ase_models.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,26 +104,23 @@ def calc(self, head=None) -> Calculator:
104104
else:
105105
self._calc = calculator_dispatch[self.model_family]()
106106
return self._calc
107-
107+
108108
@calc.setter
109109
def calc(self, value: Calculator):
110110
logging.warning("Overriding the default calculator.")
111111
self._calc = value
112112

113113
def _init_mace_calculator(self) -> Calculator:
114114
from mace.calculators import mace_mp
115+
115116
if self.model_domain == "molecules":
116117
head = "omol"
117118
else:
118119
head = "oc20_usemppbe"
119120
return mace_mp(
120-
model=self.model_path,
121-
device="cuda",
122-
default_dtype="float64",
123-
head=head
121+
model=self.model_path, device="cuda", default_dtype="float64", head=head
124122
)
125123

126-
127124
def _init_orb_calculator(self) -> Calculator:
128125
from orb_models.forcefield import pretrained
129126
from orb_models.forcefield.calculator import ORBCalculator

lambench/tasks/calculator/binding/binding.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,44 @@
1717
import logging
1818

1919

20-
21-
2220
def run_inference(
2321
model: ASEModel,
2422
test_data: Path,
2523
) -> dict[str, float]:
26-
27-
active_site_atoms = read(test_data / "active_site.traj",":")
28-
drug_atoms = read(test_data / "drug.traj",":")
29-
combined_atoms = read(test_data / "combined.traj",":")
24+
active_site_atoms = read(test_data / "active_site.traj", ":")
25+
drug_atoms = read(test_data / "drug.traj", ":")
26+
combined_atoms = read(test_data / "combined.traj", ":")
3027
labels = np.load(test_data / "labels.npy")
3128

3229
EV_TO_KCAL = 23.06092234465
3330

34-
calc = model.calc
31+
calc = model.calc
3532
preds = []
3633
success_labels = []
3734

38-
for site, drug, combo, label in tqdm(zip(active_site_atoms, drug_atoms, combined_atoms, labels)):
35+
for site, drug, combo, label in tqdm(
36+
zip(active_site_atoms, drug_atoms, combined_atoms, labels)
37+
):
3938
try:
4039
for atoms in (site, drug, combo):
4140
atoms.calc = calc
42-
atoms.info.update({"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])})
41+
atoms.info.update(
42+
{"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])}
43+
)
4344

4445
site_energy = site.get_potential_energy()
4546
drug_energy = drug.get_potential_energy()
4647
combo_energy = combo.get_potential_energy()
4748

48-
binding_energy = combo_energy - site_energy - drug_energy
49+
binding_energy = combo_energy - site_energy - drug_energy
4950
preds.append(binding_energy * EV_TO_KCAL)
5051
success_labels.append(label)
5152
except Exception as e:
5253
logging.warning(f"Failed to calculate binding energy for one sample: {e}")
5354
continue
54-
5555

5656
return {
5757
"MAE": mean_absolute_error(success_labels, preds), # kcal/mol
5858
"RMSE": root_mean_squared_error(success_labels, preds), # kcal/mol
5959
"success_rate": len(success_labels) / len(labels),
60-
}
60+
}

lambench/tasks/calculator/calculator_tasks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ binding_energy:
3636
calculator_params: null
3737
rxn_barrier:
3838
test_data: /bohr/lambench-BH876-uplk/v1/BH876
39-
calculator_params: null
39+
calculator_params: null
Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""
22
The test data is retrieved from:
33
4-
@misc{liang2025gold,
5-
title={Gold-Standard Chemical Database 137 (GSCDB137): A diverse set of accurate energy differences for assessing and developing density functionals},
4+
@misc{liang2025gold,
5+
title={Gold-Standard Chemical Database 137 (GSCDB137): A diverse set of accurate energy differences for assessing and developing density functionals},
66
author={Jiashu Liang and Martin Head-Gordon},
77
year={2025},
88
eprint={2508.13468},
99
archivePrefix={arXiv},
1010
primaryClass={physics.chem-ph},
11-
url={https://arxiv.org/abs/2508.13468},
11+
url={https://arxiv.org/abs/2508.13468},
1212
}
1313
1414
https://github.com/JiashuLiang/GSCDB
@@ -27,13 +27,10 @@
2727
import logging
2828

2929

30-
31-
3230
def run_inference(
3331
model: ASEModel,
3432
test_data: Path,
3533
) -> dict[str, float]:
36-
3734
lookup_table = pd.read_csv(test_data / "lookup_table.csv")
3835
lookup_table.reset_index(inplace=True)
3936
stoichiometry = pd.read_csv(test_data / "stoichiometry.csv")
@@ -46,31 +43,36 @@ def run_inference(
4643
labels = []
4744
success = len(stoichiometry)
4845

49-
calc = model.calc
46+
calc = model.calc
5047

5148
for i, row in tqdm(stoichiometry.iterrows()):
5249
try:
5350
reactions = row["Stoichiometry"].split(",")
54-
num_species = len(reactions) // 2
51+
num_species = len(reactions) // 2
5552
pred = 0
5653
for i in range(num_species):
57-
stoi = float(reactions[2*i])
58-
reactant = reactions[2*i+1]
59-
structure_index = lookup_table[lookup_table["ID"] == reactant].index.values[0]
54+
stoi = float(reactions[2 * i])
55+
reactant = reactions[2 * i + 1]
56+
structure_index = lookup_table[
57+
lookup_table["ID"] == reactant
58+
].index.values[0]
6059
atoms = traj[structure_index]
61-
atoms.info.update({"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])})
60+
atoms.info.update(
61+
{"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])}
62+
)
6263
atoms.calc = calc
6364
energy = atoms.get_potential_energy()
6465
pred += stoi * energy
6566
preds.append(pred * EV_TO_KCAL)
6667
labels.append(row["Reference"] * HARTREE_TO_KCAL)
67-
except:
68-
logging.warning(f"Failed to calculate reaction energy for reaction: {row['Stoichiometry']}")
68+
except Exception as e:
69+
logging.warning(
70+
f"Failed to calculate reaction energy for reaction: {row['Stoichiometry']}. Error: {e}"
71+
)
6972
success -= 1
70-
7173

7274
return {
7375
"MAE": mean_absolute_error(labels, preds), # kcal/mol
7476
"RMSE": root_mean_squared_error(labels, preds), # kcal/mol
7577
"success_rate": success / len(stoichiometry),
76-
}
78+
}

0 commit comments

Comments
 (0)