Skip to content

Commit 32484f0

Browse files
benrhodes26 (ben rhodes)
and co-author authored
Fix device issue (#17)
Co-authored-by: ben rhodes <benrhodes@bens-MacBook-Pro.local>
1 parent 2cedb28 commit 32484f0

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,17 @@ from orb_models.forcefield import pretrained
4242
from orb_models.forcefield import atomic_system
4343
from orb_models.forcefield.base import batch_graphs
4444

45-
orbff = pretrained.orb_v1()
45+
device = "cpu" # or device="cuda"
46+
orbff = pretrained.orb_v1(device=device)
4647
atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
47-
graph = atomic_system.ase_atoms_to_atom_graphs(atoms)
48+
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, device=device)
4849

4950
# Optionally, batch graphs for faster inference
5051
# graph = batch_graphs([graph, graph, ...])
5152

5253
result = orbff.predict(graph)
5354

54-
# Convert to ASE atoms (this will also unbatch the results)
55+
# Convert to ASE atoms (unbatches the results and transfers to cpu if necessary)
5556
atoms = atomic_system.atom_graphs_to_ase_atoms(
5657
graph,
5758
energy=result["graph_pred"],

orb_models/forcefield/atomic_system.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def ase_atoms_to_atom_graphs(
9696
),
9797
system_id: Optional[int] = None,
9898
brute_force_knn: Optional[bool] = None,
99+
device: Optional[torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
99100
) -> AtomGraphs:
100101
"""Generate AtomGraphs from an ase.Atoms object.
101102
@@ -107,6 +108,7 @@ def ase_atoms_to_atom_graphs(
107108
Defaults to None, in which case brute_force is used if a GPU is available (2-6x faster),
108109
but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU,
109110
so it is recommended to set to False in that case.
111+
device: device to put the tensors on.
110112
111113
Returns:
112114
AtomGraphs object
@@ -133,7 +135,7 @@ def ase_atoms_to_atom_graphs(
133135
)
134136

135137
num_atoms = len(node_feats["positions"]) # type: ignore
136-
return AtomGraphs(
138+
atom_graph = AtomGraphs(
137139
senders=senders,
138140
receivers=receivers,
139141
n_node=torch.tensor([num_atoms]),
@@ -147,6 +149,7 @@ def ase_atoms_to_atom_graphs(
147149
radius=system_config.radius,
148150
max_num_neighbors=system_config.max_num_neighbors,
149151
)
152+
return atom_graph.to(device)
150153

151154

152155
def _get_edge_feats(

0 commit comments

Comments (0)