@@ -82,7 +82,7 @@ def __init__(self, *args, **kwargs):
8282 self ._calc = None
8383
8484 @property
85- def calc(self) -> Calculator:
85+ def calc (self , head = None ) -> Calculator :
8686 """ASE Calculator with the model loaded."""
8787 calculator_dispatch = {
8888 "MACE" : self ._init_mace_calculator ,
@@ -101,7 +101,6 @@ def calc(self) -> Calculator:
101101 f"Model { self .model_name } is not supported by ASEModel, using EMT as default calculator."
102102 )
103103 self ._calc = EMT ()
104-
105104 else :
106105 self ._calc = calculator_dispatch [self .model_family ]()
107106 return self ._calc
@@ -114,10 +113,12 @@ def calc(self, value: Calculator):
114113 def _init_mace_calculator (self ) -> Calculator :
115114 from mace .calculators import mace_mp
116115
116+ if self .model_domain == "molecules" :
117+ head = "omol"
118+ else :
119+ head = "oc20_usemppbe"
117120 return mace_mp (
118- model=self.model_name.split("_")[-1],
119- device="cuda",
120- default_dtype="float64",
121+ model = self .model_path , device = "cuda" , default_dtype = "float64" , head = head
121122 )
122123
123124 def _init_orb_calculator (self ) -> Calculator :
@@ -134,7 +135,7 @@ def _init_sevennet_calculator(self) -> Calculator:
134135
135136 model_config = {"model" : self .model_name , "device" : "cuda" }
136137 if self .model_name == "7net-mf-ompa" :
137- model_config["modal"] = "mpa "
138+ model_config ["modal" ] = "omat24 "
138139 return SevenNetCalculator (** model_config )
139140
140141 def _init_equiformer_calculator (self ) -> Calculator :
@@ -171,7 +172,7 @@ def _init_dp_calculator(self) -> Calculator:
171172 else :
172173 return DP (
173174 model = self .model_path ,
174- head="MP_traj_v024_alldata_mixu ",
175+ head = "Omat24 " ,
175176 )
176177
177178 def _init_grace_calculator (self ) -> Calculator :
@@ -290,6 +291,16 @@ def evaluate(
290291 elif task .task_name == "vacancy" :
291292 from lambench .tasks .calculator .vacancy .vacancy import run_inference
292293
294+ assert task .test_data is not None
295+ return {"metrics" : run_inference (self , task .test_data )}
296+ elif task .task_name == "rxn_barrier" :
297+ from lambench .tasks .calculator .rxn_barrier .barrier import run_inference
298+
299+ assert task .test_data is not None
300+ return {"metrics" : run_inference (self , task .test_data )}
301+ elif task .task_name == "binding_energy" :
302+ from lambench .tasks .calculator .binding .binding import run_inference
303+
293304 assert task .test_data is not None
294305 return {"metrics" : run_inference (self , task .test_data )}
295306 else :
0 commit comments