Skip to content

Commit d545e88

Browse files
authored
Fixed device selection in MACEPotential (#116)
1 parent 3c1e9d3 commit d545e88

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

openmmml/models/macepotential.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,13 @@ def addForces(
163163
device = self._getTorchDevice(args)
164164
if self.name in models:
165165
fn, name, warn = models[self.name]
166-
model = fn(model=name, device=device, return_raw_model=True)
166+
model = fn(model=name, device=device, return_raw_model=True).to(device)
167167
if warn:
168168
import logging
169169
logging.warning(f'The model {self.name} is distributed under the restrictive ASL license. Commercial use is not permitted.')
170170
elif self.name == "mace":
171171
if self.modelPath is not None:
172-
model = torch.load(self.modelPath, map_location="cpu")
172+
model = torch.load(self.modelPath, map_location=device)
173173
else:
174174
raise ValueError("No modelPath provided for local MACE model.")
175175
else:

0 commit comments

Comments
 (0)