Skip to content

Commit 3aa0175

Browse files
Merge branch 'main' into ag/batch_cell_list
2 parents 1d7075f + 7630103 commit 3aa0175

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torch_sim/models/mace.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def __init__(
155155

156156
# Load model if provided as path
157157
if isinstance(model, str | Path):
158-
self.model = torch.load(model, map_location=self._device, weights_only=False)
158+
self.model = torch.load(model, map_location=self.device, weights_only=False)
159159
elif isinstance(model, torch.nn.Module):
160-
self.model = model.to(self._device)
160+
self.model = model.to(self.device)
161161
else:
162162
raise TypeError("Model must be a path or torch.nn.Module")
163163

@@ -170,7 +170,7 @@ def __init__(
170170

171171
if enable_cueq:
172172
print("Converting models to CuEq for acceleration") # noqa: T201
173-
self.model = run_e3nn_to_cueq(self.model)
173+
self.model = run_e3nn_to_cueq(self.model, device=self.device.type)
174174

175175
# Set model properties
176176
self.r_max = self.model.r_max

0 commit comments

Comments
 (0)