Skip to content

Commit aca46ad

Browse files
authored
Merge pull request #373 from anyangml2nd/feat/support-ood-dual-label
Feat: support OOD dual label
2 parents 242201b + fb1ada5 commit aca46ad

File tree

5 files changed, with 31 additions and 8 deletions.

lambench/models/ase_models.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,16 @@ def evaluate(
201201
import torch
202202

203203
torch.set_default_dtype(torch.float32)
204-
return self.run_ase_dptest(self, task.test_data, task.dispersion_correction)
204+
# Use corresponding DFT label for models supporting OMol25 on Molecules tasks
205+
if isinstance(task.test_data, dict):
206+
if self.supports_omol and self.model_domain == "molecules":
207+
data_path = task.test_data["wB97"]
208+
else:
209+
data_path = task.test_data["PBE"]
210+
else:
211+
data_path = task.test_data
212+
213+
return self.run_ase_dptest(self, data_path, task.dispersion_correction)
205214
elif isinstance(task, CalculatorTask):
206215
if task.task_name == "nve_md":
207216
from lambench.tasks.calculator.nve_md.nve_md import (

lambench/tasks/base_task.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ class BaseTask(BaseModel):
2626
"""
2727

2828
task_name: str
29-
test_data: Path
29+
test_data: Path | dict[str, Path]
3030
task_config: ClassVar[Path]
3131
model_config = ConfigDict(extra="allow")
3232
workdir: Path = Path(tempfile.gettempdir()) / "lambench"

lambench/tasks/direct/direct_tasks.yml

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,9 @@
1-
ANI:
2-
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
31
HEA25_S:
42
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
53
HEA25_bulk:
64
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
75
MoS2:
86
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
9-
MD22:
10-
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
117
REANN_CO2_Ni100:
128
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
139
NequIP_NC_2022:
@@ -24,6 +20,10 @@ HPt_NC_2022:
2420
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
2521
Ca_batteries_CM2021:
2622
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
23+
AQM:
24+
test_data:
25+
PBE: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000"
26+
wB97: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000_OMol-wb97mv-def2tzvpd-ORCA600"
2727
## DEPRECATED
2828
# Collision:
2929
# test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
@@ -39,3 +39,7 @@ Ca_batteries_CM2021:
3939
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
4040
# Torsionnet500:
4141
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"
42+
# ANI:
43+
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
44+
# MD22:
45+
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"

lambench/workflow/dflow.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,16 @@ def submit_tasks_dflow(
5353
name = f"{task.task_name}--{model.model_name}"
5454
# dflow task name should be alphanumeric
5555
name = "".join([c if c.isalnum() else "-" for c in name])
56+
if task.test_data is not None:
57+
# handle dict type test_data, NOTE: if the datasets are in the same parent folder, only need to upload the artifact once.
58+
task_data = (
59+
list(task.test_data.values())[0]
60+
if isinstance(task.test_data, dict)
61+
else task.test_data
62+
)
63+
else:
64+
task_data = []
65+
logging.warning(f"Submitting task {name} with test data paths: {task_data}")
5666

5767
dflow_task = Task(
5868
name=name,
@@ -69,7 +79,7 @@ def submit_tasks_dflow(
6979
"task": task,
7080
"model": model,
7181
},
72-
artifacts={"dataset": get_dataset([model.model_path, task.test_data])},
82+
artifacts={"dataset": get_dataset([model.model_path, task_data])},
7383
executor=DispatcherExecutor(
7484
machine_dict={
7585
"batch_type": "Bohrium",

lambench/workflow/entrypoint.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def gather_task_type(
6666
continue # Regular ASEModel does not support PropertyFinetuneTask
6767
for task_name, task_params in task_configs.items():
6868
if (task_names and task_name not in task_names) or task_class.__name__ in (
69-
model_param["skip_tasks"]
69+
model_param.get("skip_tasks", [])
7070
):
7171
continue
7272
task = task_class(task_name=task_name, **task_params)

0 commit comments

Comments
 (0)