File tree Expand file tree Collapse file tree 5 files changed +31
-8
lines changed
Expand file tree Collapse file tree 5 files changed +31
-8
lines changed Original file line number Diff line number Diff line change @@ -201,7 +201,16 @@ def evaluate(
201201 import torch
202202
203203 torch .set_default_dtype (torch .float32 )
204- return self .run_ase_dptest (self , task .test_data , task .dispersion_correction )
204+ # Use corresponding DFT label for models supporting OMol25 on Molecules tasks
205+ if isinstance (task .test_data , dict ):
206+ if self .supports_omol and self .model_domain == "molecules" :
207+ data_path = task .test_data ["wB97" ]
208+ else :
209+ data_path = task .test_data ["PBE" ]
210+ else :
211+ data_path = task .test_data
212+
213+ return self .run_ase_dptest (self , data_path , task .dispersion_correction )
205214 elif isinstance (task , CalculatorTask ):
206215 if task .task_name == "nve_md" :
207216 from lambench .tasks .calculator .nve_md .nve_md import (
Original file line number Diff line number Diff line change @@ -26,7 +26,7 @@ class BaseTask(BaseModel):
2626 """
2727
2828 task_name : str
29- test_data : Path
29+ test_data : Path | dict [ str , Path ]
3030 task_config : ClassVar [Path ]
3131 model_config = ConfigDict (extra = "allow" )
3232 workdir : Path = Path (tempfile .gettempdir ()) / "lambench"
Original file line number Diff line number Diff line change 1- ANI :
2- test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
31HEA25_S :
42 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
53HEA25_bulk :
64 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
75MoS2 :
86 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
9- MD22 :
10- test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
117REANN_CO2_Ni100 :
128 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
139NequIP_NC_2022 :
@@ -24,6 +20,10 @@ HPt_NC_2022:
2420 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
2521Ca_batteries_CM2021 :
2622 test_data : " /bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
23+ AQM :
24+ test_data :
25+ PBE : " /bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000"
26+ wB97 : " /bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000_OMol-wb97mv-def2tzvpd-ORCA600"
2727# # DEPRECATED
2828# Collision:
2929# test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
@@ -39,3 +39,7 @@ Ca_batteries_CM2021:
3939# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
4040# Torsionnet500:
4141# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"
42+ # ANI:
43+ # test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
44+ # MD22:
45+ # test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
Original file line number Diff line number Diff line change @@ -53,6 +53,16 @@ def submit_tasks_dflow(
5353 name = f"{ task .task_name } --{ model .model_name } "
5454 # dflow task name should be alphanumeric
5555 name = "" .join ([c if c .isalnum () else "-" for c in name ])
56+ if task .test_data is not None :
57+ # handle dict type test_data, NOTE: if the datasets are in the same parent folder, only need to upload the artifact once.
58+ task_data = (
59+ list (task .test_data .values ())[0 ]
60+ if isinstance (task .test_data , dict )
61+ else task .test_data
62+ )
63+ else :
64+ task_data = []
65+ logging .warning (f"Submitting task { name } with test data paths: { task_data } " )
5666
5767 dflow_task = Task (
5868 name = name ,
@@ -69,7 +79,7 @@ def submit_tasks_dflow(
6979 "task" : task ,
7080 "model" : model ,
7181 },
72- artifacts = {"dataset" : get_dataset ([model .model_path , task . test_data ])},
82+ artifacts = {"dataset" : get_dataset ([model .model_path , task_data ])},
7383 executor = DispatcherExecutor (
7484 machine_dict = {
7585 "batch_type" : "Bohrium" ,
Original file line number Diff line number Diff line change @@ -66,7 +66,7 @@ def gather_task_type(
6666 continue # Regular ASEModel does not support PropertyFinetuneTask
6767 for task_name , task_params in task_configs .items ():
6868 if (task_names and task_name not in task_names ) or task_class .__name__ in (
69- model_param [ "skip_tasks" ]
69+ model_param . get ( "skip_tasks" , [])
7070 ):
7171 continue
7272 task = task_class (task_name = task_name , ** task_params )
You can’t perform that action at this time.
0 commit comments