@@ -73,10 +73,12 @@ class TorchMDNetPotentialImpl(MLPotentialImpl):
7373
7474 Pretained AceFF models can be used directly:
7575
76- >>> potential = MLPotential('aceff-1 .0')
76+ >>> potential = MLPotential('aceff-2 .0')
7777
7878 >>> potential = MLPotential('aceff-1.1')
7979
80+ >>> potential = MLPotential('aceff-1.0')
81+
8082 """
8183
8284 def __init__ (self ,
@@ -144,6 +146,9 @@ def addForces(self,
144146 elif self .name == 'aceff-1.1' :
145147 repo_id = "Acellera/AceFF-1.1"
146148 filename = "aceff_v1.1.ckpt"
149+ elif self .name == 'aceff-2.0' :
150+ repo_id = "Acellera/AceFF-2.0"
151+ filename = "aceff_v2.0.ckpt"
147152 else :
148153 raise ValueError (f'Model name { self .name } does not exist.' )
149154
@@ -170,7 +175,11 @@ def addForces(self,
170175 batch = torch .tensor (batch , dtype = torch .long )
171176
172177 # TensorNet models can use CUDA graphs and the default is to use them.
173- use_cudagraphs = args .get ('cudaGraphs' , True if isinstance (model .representation_model , torchmdnet .models .tensornet .TensorNet ) else False )
178+ use_cudagraphs = args .get ('cudaGraphs' ,
179+ True if (isinstance (model .representation_model , torchmdnet .models .tensornet .TensorNet )
180+ or isinstance (model .representation_model , torchmdnet .models .tensornet2 .TensorNet2 ))
181+ else False
182+ )
174183
175184 class TorchMDNetForce (torch .nn .Module ):
176185 def __init__ (self , model , numbers , charge , atoms , batch , lengthScale , energyScale ):
0 commit comments